git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

add to CHANGELOG
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 import logging
4 from concurrent.futures import ThreadPoolExecutor
5 from contextlib import contextmanager, redirect_stderr
6 from functools import partial, wraps
7 from io import BytesIO, TextIOWrapper
8 import os
9 from pathlib import Path
10 import re
11 import sys
12 from tempfile import TemporaryDirectory
13 from typing import (
14     Any,
15     BinaryIO,
16     Callable,
17     Coroutine,
18     Generator,
19     List,
20     Tuple,
21     Iterator,
22     TypeVar,
23 )
24 import unittest
25 from unittest.mock import patch, MagicMock
26
27 from click import unstyle
28 from click.testing import CliRunner
29
30 import black
31 from black import Feature, TargetVersion
32
33 try:
34     import blackd
35     from aiohttp.test_utils import TestClient, TestServer
36 except ImportError:
37     has_blackd_deps = False
38 else:
39     has_blackd_deps = True
40
# Shorthand for `black.format_file_in_place` with the default `FileMode` and
# `fast=True` pre-bound.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
# Shorthand for `black.format_str` with the default `FileMode` pre-bound.
fs = partial(black.format_str, mode=black.FileMode())
# Location of this test module and of the directory that contains it.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that `read_data` strips from every fixture line.
# NOTE(review): presumably written as two concatenated literals so that this
# line itself never contains the full marker -- confirm.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# One `--target-version=<name>` CLI flag per entry of `black.PY36_VERSIONS`.
PY36_ARGS = [
    f"--target-version={version.name.lower()}" for version in black.PY36_VERSIONS
]
# Generic type variables used by helper signatures in this module.
T = TypeVar("T")
R = TypeVar("R")
51
52
def dump_to_stderr(*output: str) -> str:
    """Return `output` joined by newlines, framed by a leading and trailing newline.

    Despite the name, nothing is written to stderr here; the joined text is
    simply handed back to the caller.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
55
56
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    `name` gets a ".py" suffix unless it already ends in .py/.pyi/.out/.diff.
    The file is looked up in `THIS_DIR / "data"` when `data` is true, otherwise
    directly in `THIS_DIR`.  Lines before a "# output" line form the input,
    lines after it form the output; when no output lines were collected, the
    output is a copy of the input.  `EMPTY_LINE` is removed from every line and
    both results are stripped and terminated by a single newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    directory = THIS_DIR / "data" if data else THIS_DIR
    with open(directory / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    # Index 0 collects the input section, index 1 the expected output section.
    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0
    for raw in lines:
        text = raw.replace(EMPTY_LINE, "")
        if text.rstrip() == "# output":
            current = 1
        else:
            sections[current].append(text)
    source_lines, expected_lines = sections
    if source_lines and not expected_lines:
        # If there's no output marker, treat the entire file as already pre-formatted.
        expected_lines = list(source_lines)
    source = "".join(source_lines).strip() + "\n"
    expected = "".join(expected_lines).strip() + "\n"
    return source, expected
78
79
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch `black.CACHE_DIR` to a path inside a fresh temporary directory.

    With `exists` true the temporary directory itself is used; otherwise a
    "new" child path (which is not created here) is used.  The chosen path is
    yielded while the patch is active.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        location = root if exists else root / "new"
        with patch("black.CACHE_DIR", location):
            yield location
88
89
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a new asyncio event loop for the duration of the `with` block.

    The loop that was current on entry is reinstated on exit; the temporary
    loop is closed afterwards only when `close` is true.
    """
    loop_policy = asyncio.get_event_loop_policy()
    previous = loop_policy.get_event_loop()
    fresh = loop_policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield

    finally:
        loop_policy.set_event_loop(previous)
        if close:
            fresh.close()
103
104
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn coroutine function `f` into a plain callable for synchronous callers.

    Each call runs `f(*args, **kwargs)` to completion on a temporary event
    loop provided by `event_loop(close=True)`.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        with event_loop(close=True):
            asyncio.get_event_loop().run_until_complete(f(*args, **kwargs))

    return wrapper
112
113
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    After an invocation, the raw bytes written to the two streams are available
    as `stdout_bytes` and `stderr_bytes`.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Wrap `CliRunner.isolation`, redirecting `sys.stderr` to `self.stderrbuf`.

        On exit the bytes captured on stdout and stderr are stored on the
        runner and the previous `sys.stderr` is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            # Captured before the `try` so the `finally` clause can always restore it.
            hold_stderr = sys.stderr
            try:
                sys.stderr = TextIOWrapper(self.stderrbuf, encoding=self.charset)
                yield output
            finally:
                # Text wrappers may still hold unwritten data; flush so that
                # getvalue() below sees everything that was written.
                sys.stdout.flush()
                sys.stderr.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = sys.stderr.buffer.getvalue()  # type: ignore
                sys.stderr = hold_stderr
137
138
139 class BlackTestCase(unittest.TestCase):
140     maxDiff = None
141
142     def assertFormatEqual(self, expected: str, actual: str) -> None:
143         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
144             bdv: black.DebugVisitor[Any]
145             black.out("Expected tree:", fg="green")
146             try:
147                 exp_node = black.lib2to3_parse(expected)
148                 bdv = black.DebugVisitor()
149                 list(bdv.visit(exp_node))
150             except Exception as ve:
151                 black.err(str(ve))
152             black.out("Actual tree:", fg="red")
153             try:
154                 exp_node = black.lib2to3_parse(actual)
155                 bdv = black.DebugVisitor()
156                 list(bdv.visit(exp_node))
157             except Exception as ve:
158                 black.err(str(ve))
159         self.assertEqual(expected, actual)
160
161     def invokeBlack(
162         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
163     ) -> None:
164         runner = BlackRunner()
165         if ignore_config:
166             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
167         result = runner.invoke(black.main, args)
168         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_empty(self) -> None:
172         source = expected = ""
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, black.FileMode())
177
178     def test_empty_ff(self) -> None:
179         expected = ""
180         tmp_file = Path(black.dump_to_file())
181         try:
182             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
183             with open(tmp_file, encoding="utf8") as f:
184                 actual = f.read()
185         finally:
186             os.unlink(tmp_file)
187         self.assertFormatEqual(expected, actual)
188
189     @patch("black.dump_to_file", dump_to_stderr)
190     def test_self(self) -> None:
191         source, expected = read_data("test_black", data=False)
192         actual = fs(source)
193         self.assertFormatEqual(expected, actual)
194         black.assert_equivalent(source, actual)
195         black.assert_stable(source, actual, black.FileMode())
196         self.assertFalse(ff(THIS_FILE))
197
198     @patch("black.dump_to_file", dump_to_stderr)
199     def test_black(self) -> None:
200         source, expected = read_data("../black", data=False)
201         actual = fs(source)
202         self.assertFormatEqual(expected, actual)
203         black.assert_equivalent(source, actual)
204         black.assert_stable(source, actual, black.FileMode())
205         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
206
207     def test_piping(self) -> None:
208         source, expected = read_data("../black", data=False)
209         result = BlackRunner().invoke(
210             black.main,
211             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
212             input=BytesIO(source.encode("utf8")),
213         )
214         self.assertEqual(result.exit_code, 0)
215         self.assertFormatEqual(expected, result.output)
216         black.assert_equivalent(source, result.output)
217         black.assert_stable(source, result.output, black.FileMode())
218
219     def test_piping_diff(self) -> None:
220         diff_header = re.compile(
221             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
222             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
223         )
224         source, _ = read_data("expression.py")
225         expected, _ = read_data("expression.diff")
226         config = THIS_DIR / "data" / "empty_pyproject.toml"
227         args = [
228             "-",
229             "--fast",
230             f"--line-length={black.DEFAULT_LINE_LENGTH}",
231             "--diff",
232             f"--config={config}",
233         ]
234         result = BlackRunner().invoke(
235             black.main, args, input=BytesIO(source.encode("utf8"))
236         )
237         self.assertEqual(result.exit_code, 0)
238         actual = diff_header.sub("[Deterministic header]", result.output)
239         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
240         self.assertEqual(expected, actual)
241
242     @patch("black.dump_to_file", dump_to_stderr)
243     def test_setup(self) -> None:
244         source, expected = read_data("../setup", data=False)
245         actual = fs(source)
246         self.assertFormatEqual(expected, actual)
247         black.assert_equivalent(source, actual)
248         black.assert_stable(source, actual, black.FileMode())
249         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
250
251     @patch("black.dump_to_file", dump_to_stderr)
252     def test_function(self) -> None:
253         source, expected = read_data("function")
254         actual = fs(source)
255         self.assertFormatEqual(expected, actual)
256         black.assert_equivalent(source, actual)
257         black.assert_stable(source, actual, black.FileMode())
258
259     @patch("black.dump_to_file", dump_to_stderr)
260     def test_function2(self) -> None:
261         source, expected = read_data("function2")
262         actual = fs(source)
263         self.assertFormatEqual(expected, actual)
264         black.assert_equivalent(source, actual)
265         black.assert_stable(source, actual, black.FileMode())
266
267     @patch("black.dump_to_file", dump_to_stderr)
268     def test_expression(self) -> None:
269         source, expected = read_data("expression")
270         actual = fs(source)
271         self.assertFormatEqual(expected, actual)
272         black.assert_equivalent(source, actual)
273         black.assert_stable(source, actual, black.FileMode())
274
275     def test_expression_ff(self) -> None:
276         source, expected = read_data("expression")
277         tmp_file = Path(black.dump_to_file(source))
278         try:
279             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
280             with open(tmp_file, encoding="utf8") as f:
281                 actual = f.read()
282         finally:
283             os.unlink(tmp_file)
284         self.assertFormatEqual(expected, actual)
285         with patch("black.dump_to_file", dump_to_stderr):
286             black.assert_equivalent(source, actual)
287             black.assert_stable(source, actual, black.FileMode())
288
289     def test_expression_diff(self) -> None:
290         source, _ = read_data("expression.py")
291         expected, _ = read_data("expression.diff")
292         tmp_file = Path(black.dump_to_file(source))
293         diff_header = re.compile(
294             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
295             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
296         )
297         try:
298             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
299             self.assertEqual(result.exit_code, 0)
300         finally:
301             os.unlink(tmp_file)
302         actual = result.output
303         actual = diff_header.sub("[Deterministic header]", actual)
304         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
305         if expected != actual:
306             dump = black.dump_to_file(actual)
307             msg = (
308                 f"Expected diff isn't equal to the actual. If you made changes "
309                 f"to expression.py and this is an anticipated difference, "
310                 f"overwrite tests/data/expression.diff with {dump}"
311             )
312             self.assertEqual(expected, actual, msg)
313
314     @patch("black.dump_to_file", dump_to_stderr)
315     def test_fstring(self) -> None:
316         source, expected = read_data("fstring")
317         actual = fs(source)
318         self.assertFormatEqual(expected, actual)
319         black.assert_equivalent(source, actual)
320         black.assert_stable(source, actual, black.FileMode())
321
322     @patch("black.dump_to_file", dump_to_stderr)
323     def test_string_quotes(self) -> None:
324         source, expected = read_data("string_quotes")
325         actual = fs(source)
326         self.assertFormatEqual(expected, actual)
327         black.assert_equivalent(source, actual)
328         black.assert_stable(source, actual, black.FileMode())
329         mode = black.FileMode(string_normalization=False)
330         not_normalized = fs(source, mode=mode)
331         self.assertFormatEqual(source, not_normalized)
332         black.assert_equivalent(source, not_normalized)
333         black.assert_stable(source, not_normalized, mode=mode)
334
335     @patch("black.dump_to_file", dump_to_stderr)
336     def test_slices(self) -> None:
337         source, expected = read_data("slices")
338         actual = fs(source)
339         self.assertFormatEqual(expected, actual)
340         black.assert_equivalent(source, actual)
341         black.assert_stable(source, actual, black.FileMode())
342
343     @patch("black.dump_to_file", dump_to_stderr)
344     def test_comments(self) -> None:
345         source, expected = read_data("comments")
346         actual = fs(source)
347         self.assertFormatEqual(expected, actual)
348         black.assert_equivalent(source, actual)
349         black.assert_stable(source, actual, black.FileMode())
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_comments2(self) -> None:
353         source, expected = read_data("comments2")
354         actual = fs(source)
355         self.assertFormatEqual(expected, actual)
356         black.assert_equivalent(source, actual)
357         black.assert_stable(source, actual, black.FileMode())
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_comments3(self) -> None:
361         source, expected = read_data("comments3")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, black.FileMode())
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_comments4(self) -> None:
369         source, expected = read_data("comments4")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, black.FileMode())
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_comments5(self) -> None:
377         source, expected = read_data("comments5")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, black.FileMode())
382
383     @patch("black.dump_to_file", dump_to_stderr)
384     def test_comments6(self) -> None:
385         source, expected = read_data("comments6")
386         actual = fs(source)
387         self.assertFormatEqual(expected, actual)
388         black.assert_equivalent(source, actual)
389         black.assert_stable(source, actual, black.FileMode())
390
391     @patch("black.dump_to_file", dump_to_stderr)
392     def test_comments7(self) -> None:
393         source, expected = read_data("comments7")
394         actual = fs(source)
395         self.assertFormatEqual(expected, actual)
396         black.assert_equivalent(source, actual)
397         black.assert_stable(source, actual, black.FileMode())
398
399     @patch("black.dump_to_file", dump_to_stderr)
400     def test_cantfit(self) -> None:
401         source, expected = read_data("cantfit")
402         actual = fs(source)
403         self.assertFormatEqual(expected, actual)
404         black.assert_equivalent(source, actual)
405         black.assert_stable(source, actual, black.FileMode())
406
407     @patch("black.dump_to_file", dump_to_stderr)
408     def test_import_spacing(self) -> None:
409         source, expected = read_data("import_spacing")
410         actual = fs(source)
411         self.assertFormatEqual(expected, actual)
412         black.assert_equivalent(source, actual)
413         black.assert_stable(source, actual, black.FileMode())
414
415     @patch("black.dump_to_file", dump_to_stderr)
416     def test_composition(self) -> None:
417         source, expected = read_data("composition")
418         actual = fs(source)
419         self.assertFormatEqual(expected, actual)
420         black.assert_equivalent(source, actual)
421         black.assert_stable(source, actual, black.FileMode())
422
423     @patch("black.dump_to_file", dump_to_stderr)
424     def test_empty_lines(self) -> None:
425         source, expected = read_data("empty_lines")
426         actual = fs(source)
427         self.assertFormatEqual(expected, actual)
428         black.assert_equivalent(source, actual)
429         black.assert_stable(source, actual, black.FileMode())
430
431     @patch("black.dump_to_file", dump_to_stderr)
432     def test_string_prefixes(self) -> None:
433         source, expected = read_data("string_prefixes")
434         actual = fs(source)
435         self.assertFormatEqual(expected, actual)
436         black.assert_equivalent(source, actual)
437         black.assert_stable(source, actual, black.FileMode())
438
439     @patch("black.dump_to_file", dump_to_stderr)
440     def test_numeric_literals(self) -> None:
441         source, expected = read_data("numeric_literals")
442         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
443         actual = fs(source, mode=mode)
444         self.assertFormatEqual(expected, actual)
445         black.assert_equivalent(source, actual)
446         black.assert_stable(source, actual, mode)
447
448     @patch("black.dump_to_file", dump_to_stderr)
449     def test_numeric_literals_ignoring_underscores(self) -> None:
450         source, expected = read_data("numeric_literals_skip_underscores")
451         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
452         actual = fs(source, mode=mode)
453         self.assertFormatEqual(expected, actual)
454         black.assert_equivalent(source, actual)
455         black.assert_stable(source, actual, mode)
456
457     @patch("black.dump_to_file", dump_to_stderr)
458     def test_numeric_literals_py2(self) -> None:
459         source, expected = read_data("numeric_literals_py2")
460         actual = fs(source)
461         self.assertFormatEqual(expected, actual)
462         black.assert_stable(source, actual, black.FileMode())
463
464     @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        """Format the `python2` fixture; only output and stability are checked."""
        source, expected = read_data("python2")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        # NOTE(review): the equivalence check below is deliberately left disabled;
        # presumably the Python 2 fixture cannot be parsed by the running
        # interpreter's `ast` module -- confirm before re-enabling.
        # black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, black.FileMode())
471
472     @patch("black.dump_to_file", dump_to_stderr)
473     def test_python2_print_function(self) -> None:
474         source, expected = read_data("python2_print_function")
475         mode = black.FileMode(target_versions={TargetVersion.PY27})
476         actual = fs(source, mode=mode)
477         self.assertFormatEqual(expected, actual)
478         black.assert_stable(source, actual, mode)
479
480     @patch("black.dump_to_file", dump_to_stderr)
481     def test_python2_unicode_literals(self) -> None:
482         source, expected = read_data("python2_unicode_literals")
483         actual = fs(source)
484         self.assertFormatEqual(expected, actual)
485         black.assert_stable(source, actual, black.FileMode())
486
487     @patch("black.dump_to_file", dump_to_stderr)
488     def test_stub(self) -> None:
489         mode = black.FileMode(is_pyi=True)
490         source, expected = read_data("stub.pyi")
491         actual = fs(source, mode=mode)
492         self.assertFormatEqual(expected, actual)
493         black.assert_stable(source, actual, mode)
494
495     @patch("black.dump_to_file", dump_to_stderr)
496     def test_python37(self) -> None:
497         source, expected = read_data("python37")
498         actual = fs(source)
499         self.assertFormatEqual(expected, actual)
500         major, minor = sys.version_info[:2]
501         if major > 3 or (major == 3 and minor >= 7):
502             black.assert_equivalent(source, actual)
503         black.assert_stable(source, actual, black.FileMode())
504
505     @patch("black.dump_to_file", dump_to_stderr)
506     def test_fmtonoff(self) -> None:
507         source, expected = read_data("fmtonoff")
508         actual = fs(source)
509         self.assertFormatEqual(expected, actual)
510         black.assert_equivalent(source, actual)
511         black.assert_stable(source, actual, black.FileMode())
512
513     @patch("black.dump_to_file", dump_to_stderr)
514     def test_fmtonoff2(self) -> None:
515         source, expected = read_data("fmtonoff2")
516         actual = fs(source)
517         self.assertFormatEqual(expected, actual)
518         black.assert_equivalent(source, actual)
519         black.assert_stable(source, actual, black.FileMode())
520
521     @patch("black.dump_to_file", dump_to_stderr)
522     def test_remove_empty_parentheses_after_class(self) -> None:
523         source, expected = read_data("class_blank_parentheses")
524         actual = fs(source)
525         self.assertFormatEqual(expected, actual)
526         black.assert_equivalent(source, actual)
527         black.assert_stable(source, actual, black.FileMode())
528
529     @patch("black.dump_to_file", dump_to_stderr)
530     def test_new_line_between_class_and_code(self) -> None:
531         source, expected = read_data("class_methods_new_line")
532         actual = fs(source)
533         self.assertFormatEqual(expected, actual)
534         black.assert_equivalent(source, actual)
535         black.assert_stable(source, actual, black.FileMode())
536
537     @patch("black.dump_to_file", dump_to_stderr)
538     def test_bracket_match(self) -> None:
539         source, expected = read_data("bracketmatch")
540         actual = fs(source)
541         self.assertFormatEqual(expected, actual)
542         black.assert_equivalent(source, actual)
543         black.assert_stable(source, actual, black.FileMode())
544
545     def test_tab_comment_indentation(self) -> None:
546         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
547         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
548         self.assertFormatEqual(contents_spc, fs(contents_spc))
549         self.assertFormatEqual(contents_spc, fs(contents_tab))
550
551         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
552         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
553         self.assertFormatEqual(contents_spc, fs(contents_spc))
554         self.assertFormatEqual(contents_spc, fs(contents_tab))
555
556         # mixed tabs and spaces (valid Python 2 code)
557         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t# comment\n        pass\n"
558         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
559         self.assertFormatEqual(contents_spc, fs(contents_spc))
560         self.assertFormatEqual(contents_spc, fs(contents_tab))
561
562         contents_tab = "if 1:\n        if 2:\n\t\tpass\n\t\t# comment\n        pass\n"
563         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
564         self.assertFormatEqual(contents_spc, fs(contents_spc))
565         self.assertFormatEqual(contents_spc, fs(contents_tab))
566
    def test_report_verbose(self) -> None:
        """Drive a `black.Report(verbose=True)` through a fixed sequence of events.

        `black.out` and `black.err` are patched with collectors so that, after
        each event, the number of emitted messages, the latest message, the
        summary string and the return code can be asserted.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged file: one message on `out`.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            # Reformatted file.
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: counted as left unchanged in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With `check` enabled, the reformatted file yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # First failure: reported on `err`, return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored path: a message is emitted but the summary does not change.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
661
    def test_report_quiet(self) -> None:
        """Drive a `black.Report(quiet=True)` through a fixed sequence of events.

        `black.out` and `black.err` are patched with collectors.  Nothing is
        expected on `out` at any point; only failures add messages to `err`.
        The summary string and return code are asserted after each event.
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Unchanged file: no message, but the summary still counts it.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached file: counted as left unchanged in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With `check` enabled, the reformatted file yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on `err`; return code becomes 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored path: no message and no change to the summary.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to "would be" wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
748
749     def test_report_normal(self) -> None:
750         report = black.Report()
751         out_lines = []
752         err_lines = []
753
754         def out(msg: str, **kwargs: Any) -> None:
755             out_lines.append(msg)
756
757         def err(msg: str, **kwargs: Any) -> None:
758             err_lines.append(msg)
759
760         with patch("black.out", out), patch("black.err", err):
761             report.done(Path("f1"), black.Changed.NO)
762             self.assertEqual(len(out_lines), 0)
763             self.assertEqual(len(err_lines), 0)
764             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
765             self.assertEqual(report.return_code, 0)
766             report.done(Path("f2"), black.Changed.YES)
767             self.assertEqual(len(out_lines), 1)
768             self.assertEqual(len(err_lines), 0)
769             self.assertEqual(out_lines[-1], "reformatted f2")
770             self.assertEqual(
771                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
772             )
773             report.done(Path("f3"), black.Changed.CACHED)
774             self.assertEqual(len(out_lines), 1)
775             self.assertEqual(len(err_lines), 0)
776             self.assertEqual(out_lines[-1], "reformatted f2")
777             self.assertEqual(
778                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
779             )
780             self.assertEqual(report.return_code, 0)
781             report.check = True
782             self.assertEqual(report.return_code, 1)
783             report.check = False
784             report.failed(Path("e1"), "boom")
785             self.assertEqual(len(out_lines), 1)
786             self.assertEqual(len(err_lines), 1)
787             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
788             self.assertEqual(
789                 unstyle(str(report)),
790                 "1 file reformatted, 2 files left unchanged, "
791                 "1 file failed to reformat.",
792             )
793             self.assertEqual(report.return_code, 123)
794             report.done(Path("f3"), black.Changed.YES)
795             self.assertEqual(len(out_lines), 2)
796             self.assertEqual(len(err_lines), 1)
797             self.assertEqual(out_lines[-1], "reformatted f3")
798             self.assertEqual(
799                 unstyle(str(report)),
800                 "2 files reformatted, 2 files left unchanged, "
801                 "1 file failed to reformat.",
802             )
803             self.assertEqual(report.return_code, 123)
804             report.failed(Path("e2"), "boom")
805             self.assertEqual(len(out_lines), 2)
806             self.assertEqual(len(err_lines), 2)
807             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
808             self.assertEqual(
809                 unstyle(str(report)),
810                 "2 files reformatted, 2 files left unchanged, "
811                 "2 files failed to reformat.",
812             )
813             self.assertEqual(report.return_code, 123)
814             report.path_ignored(Path("wat"), "no match")
815             self.assertEqual(len(out_lines), 2)
816             self.assertEqual(len(err_lines), 2)
817             self.assertEqual(
818                 unstyle(str(report)),
819                 "2 files reformatted, 2 files left unchanged, "
820                 "2 files failed to reformat.",
821             )
822             self.assertEqual(report.return_code, 123)
823             report.done(Path("f4"), black.Changed.NO)
824             self.assertEqual(len(out_lines), 2)
825             self.assertEqual(len(err_lines), 2)
826             self.assertEqual(
827                 unstyle(str(report)),
828                 "2 files reformatted, 3 files left unchanged, "
829                 "2 files failed to reformat.",
830             )
831             self.assertEqual(report.return_code, 123)
832             report.check = True
833             self.assertEqual(
834                 unstyle(str(report)),
835                 "2 files would be reformatted, 3 files would be left unchanged, "
836                 "2 files would fail to reformat.",
837             )
838
839     def test_lib2to3_parse(self) -> None:
840         with self.assertRaises(black.InvalidInput):
841             black.lib2to3_parse("invalid syntax")
842
843         straddling = "x + y"
844         black.lib2to3_parse(straddling)
845         black.lib2to3_parse(straddling, {TargetVersion.PY27})
846         black.lib2to3_parse(straddling, {TargetVersion.PY36})
847         black.lib2to3_parse(straddling, {TargetVersion.PY27, TargetVersion.PY36})
848
849         py2_only = "print x"
850         black.lib2to3_parse(py2_only)
851         black.lib2to3_parse(py2_only, {TargetVersion.PY27})
852         with self.assertRaises(black.InvalidInput):
853             black.lib2to3_parse(py2_only, {TargetVersion.PY36})
854         with self.assertRaises(black.InvalidInput):
855             black.lib2to3_parse(py2_only, {TargetVersion.PY27, TargetVersion.PY36})
856
857         py3_only = "exec(x, end=y)"
858         black.lib2to3_parse(py3_only)
859         with self.assertRaises(black.InvalidInput):
860             black.lib2to3_parse(py3_only, {TargetVersion.PY27})
861         black.lib2to3_parse(py3_only, {TargetVersion.PY36})
862         black.lib2to3_parse(py3_only, {TargetVersion.PY27, TargetVersion.PY36})
863
864     def test_get_features_used(self) -> None:
865         node = black.lib2to3_parse("def f(*, arg): ...\n")
866         self.assertEqual(black.get_features_used(node), set())
867         node = black.lib2to3_parse("def f(*, arg,): ...\n")
868         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA_IN_DEF})
869         node = black.lib2to3_parse("f(*arg,)\n")
870         self.assertEqual(
871             black.get_features_used(node), {Feature.TRAILING_COMMA_IN_CALL}
872         )
873         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
874         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
875         node = black.lib2to3_parse("123_456\n")
876         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
877         node = black.lib2to3_parse("123456\n")
878         self.assertEqual(black.get_features_used(node), set())
879         source, expected = read_data("function")
880         node = black.lib2to3_parse(source)
881         expected_features = {
882             Feature.TRAILING_COMMA_IN_CALL,
883             Feature.TRAILING_COMMA_IN_DEF,
884             Feature.F_STRINGS,
885         }
886         self.assertEqual(black.get_features_used(node), expected_features)
887         node = black.lib2to3_parse(expected)
888         self.assertEqual(black.get_features_used(node), expected_features)
889         source, expected = read_data("expression")
890         node = black.lib2to3_parse(source)
891         self.assertEqual(black.get_features_used(node), set())
892         node = black.lib2to3_parse(expected)
893         self.assertEqual(black.get_features_used(node), set())
894
895     def test_get_future_imports(self) -> None:
896         node = black.lib2to3_parse("\n")
897         self.assertEqual(set(), black.get_future_imports(node))
898         node = black.lib2to3_parse("from __future__ import black\n")
899         self.assertEqual({"black"}, black.get_future_imports(node))
900         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
901         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
902         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
903         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
904         node = black.lib2to3_parse(
905             "from __future__ import multiple\nfrom __future__ import imports\n"
906         )
907         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
908         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
909         self.assertEqual({"black"}, black.get_future_imports(node))
910         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
911         self.assertEqual({"black"}, black.get_future_imports(node))
912         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
913         self.assertEqual(set(), black.get_future_imports(node))
914         node = black.lib2to3_parse("from some.module import black\n")
915         self.assertEqual(set(), black.get_future_imports(node))
916         node = black.lib2to3_parse(
917             "from __future__ import unicode_literals as _unicode_literals"
918         )
919         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
920         node = black.lib2to3_parse(
921             "from __future__ import unicode_literals as _lol, print"
922         )
923         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
924
925     def test_debug_visitor(self) -> None:
926         source, _ = read_data("debug_visitor.py")
927         expected, _ = read_data("debug_visitor.out")
928         out_lines = []
929         err_lines = []
930
931         def out(msg: str, **kwargs: Any) -> None:
932             out_lines.append(msg)
933
934         def err(msg: str, **kwargs: Any) -> None:
935             err_lines.append(msg)
936
937         with patch("black.out", out), patch("black.err", err):
938             black.DebugVisitor.show(source)
939         actual = "\n".join(out_lines) + "\n"
940         log_name = ""
941         if expected != actual:
942             log_name = black.dump_to_file(*out_lines)
943         self.assertEqual(
944             expected,
945             actual,
946             f"AST print out is different. Actual version dumped to {log_name}",
947         )
948
949     def test_format_file_contents(self) -> None:
950         empty = ""
951         mode = black.FileMode()
952         with self.assertRaises(black.NothingChanged):
953             black.format_file_contents(empty, mode=mode, fast=False)
954         just_nl = "\n"
955         with self.assertRaises(black.NothingChanged):
956             black.format_file_contents(just_nl, mode=mode, fast=False)
957         same = "l = [1, 2, 3]\n"
958         with self.assertRaises(black.NothingChanged):
959             black.format_file_contents(same, mode=mode, fast=False)
960         different = "l = [1,2,3]"
961         expected = same
962         actual = black.format_file_contents(different, mode=mode, fast=False)
963         self.assertEqual(expected, actual)
964         invalid = "return if you can"
965         with self.assertRaises(black.InvalidInput) as e:
966             black.format_file_contents(invalid, mode=mode, fast=False)
967         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
968
969     def test_endmarker(self) -> None:
970         n = black.lib2to3_parse("\n")
971         self.assertEqual(n.type, black.syms.file_input)
972         self.assertEqual(len(n.children), 1)
973         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
974
975     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
976     def test_assertFormatEqual(self) -> None:
977         out_lines = []
978         err_lines = []
979
980         def out(msg: str, **kwargs: Any) -> None:
981             out_lines.append(msg)
982
983         def err(msg: str, **kwargs: Any) -> None:
984             err_lines.append(msg)
985
986         with patch("black.out", out), patch("black.err", err):
987             with self.assertRaises(AssertionError):
988                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
989
990         out_str = "".join(out_lines)
991         self.assertTrue("Expected tree:" in out_str)
992         self.assertTrue("Actual tree:" in out_str)
993         self.assertEqual("".join(err_lines), "")
994
995     def test_cache_broken_file(self) -> None:
996         mode = black.FileMode()
997         with cache_dir() as workspace:
998             cache_file = black.get_cache_file(mode)
999             with cache_file.open("w") as fobj:
1000                 fobj.write("this is not a pickle")
1001             self.assertEqual(black.read_cache(mode), {})
1002             src = (workspace / "test.py").resolve()
1003             with src.open("w") as fobj:
1004                 fobj.write("print('hello')")
1005             self.invokeBlack([str(src)])
1006             cache = black.read_cache(mode)
1007             self.assertIn(src, cache)
1008
1009     def test_cache_single_file_already_cached(self) -> None:
1010         mode = black.FileMode()
1011         with cache_dir() as workspace:
1012             src = (workspace / "test.py").resolve()
1013             with src.open("w") as fobj:
1014                 fobj.write("print('hello')")
1015             black.write_cache({}, [src], mode)
1016             self.invokeBlack([str(src)])
1017             with src.open("r") as fobj:
1018                 self.assertEqual(fobj.read(), "print('hello')")
1019
1020     @event_loop(close=False)
1021     def test_cache_multiple_files(self) -> None:
1022         mode = black.FileMode()
1023         with cache_dir() as workspace, patch(
1024             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1025         ):
1026             one = (workspace / "one.py").resolve()
1027             with one.open("w") as fobj:
1028                 fobj.write("print('hello')")
1029             two = (workspace / "two.py").resolve()
1030             with two.open("w") as fobj:
1031                 fobj.write("print('hello')")
1032             black.write_cache({}, [one], mode)
1033             self.invokeBlack([str(workspace)])
1034             with one.open("r") as fobj:
1035                 self.assertEqual(fobj.read(), "print('hello')")
1036             with two.open("r") as fobj:
1037                 self.assertEqual(fobj.read(), 'print("hello")\n')
1038             cache = black.read_cache(mode)
1039             self.assertIn(one, cache)
1040             self.assertIn(two, cache)
1041
1042     def test_no_cache_when_writeback_diff(self) -> None:
1043         mode = black.FileMode()
1044         with cache_dir() as workspace:
1045             src = (workspace / "test.py").resolve()
1046             with src.open("w") as fobj:
1047                 fobj.write("print('hello')")
1048             self.invokeBlack([str(src), "--diff"])
1049             cache_file = black.get_cache_file(mode)
1050             self.assertFalse(cache_file.exists())
1051
1052     def test_no_cache_when_stdin(self) -> None:
1053         mode = black.FileMode()
1054         with cache_dir():
1055             result = CliRunner().invoke(
1056                 black.main, ["-"], input=BytesIO(b"print('hello')")
1057             )
1058             self.assertEqual(result.exit_code, 0)
1059             cache_file = black.get_cache_file(mode)
1060             self.assertFalse(cache_file.exists())
1061
1062     def test_read_cache_no_cachefile(self) -> None:
1063         mode = black.FileMode()
1064         with cache_dir():
1065             self.assertEqual(black.read_cache(mode), {})
1066
1067     def test_write_cache_read_cache(self) -> None:
1068         mode = black.FileMode()
1069         with cache_dir() as workspace:
1070             src = (workspace / "test.py").resolve()
1071             src.touch()
1072             black.write_cache({}, [src], mode)
1073             cache = black.read_cache(mode)
1074             self.assertIn(src, cache)
1075             self.assertEqual(cache[src], black.get_cache_info(src))
1076
1077     def test_filter_cached(self) -> None:
1078         with TemporaryDirectory() as workspace:
1079             path = Path(workspace)
1080             uncached = (path / "uncached").resolve()
1081             cached = (path / "cached").resolve()
1082             cached_but_changed = (path / "changed").resolve()
1083             uncached.touch()
1084             cached.touch()
1085             cached_but_changed.touch()
1086             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1087             todo, done = black.filter_cached(
1088                 cache, {uncached, cached, cached_but_changed}
1089             )
1090             self.assertEqual(todo, {uncached, cached_but_changed})
1091             self.assertEqual(done, {cached})
1092
1093     def test_write_cache_creates_directory_if_needed(self) -> None:
1094         mode = black.FileMode()
1095         with cache_dir(exists=False) as workspace:
1096             self.assertFalse(workspace.exists())
1097             black.write_cache({}, [], mode)
1098             self.assertTrue(workspace.exists())
1099
1100     @event_loop(close=False)
1101     def test_failed_formatting_does_not_get_cached(self) -> None:
1102         mode = black.FileMode()
1103         with cache_dir() as workspace, patch(
1104             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1105         ):
1106             failing = (workspace / "failing.py").resolve()
1107             with failing.open("w") as fobj:
1108                 fobj.write("not actually python")
1109             clean = (workspace / "clean.py").resolve()
1110             with clean.open("w") as fobj:
1111                 fobj.write('print("hello")\n')
1112             self.invokeBlack([str(workspace)], exit_code=123)
1113             cache = black.read_cache(mode)
1114             self.assertNotIn(failing, cache)
1115             self.assertIn(clean, cache)
1116
1117     def test_write_cache_write_fail(self) -> None:
1118         mode = black.FileMode()
1119         with cache_dir(), patch.object(Path, "open") as mock:
1120             mock.side_effect = OSError
1121             black.write_cache({}, [], mode)
1122
1123     @event_loop(close=False)
1124     def test_check_diff_use_together(self) -> None:
1125         with cache_dir():
1126             # Files which will be reformatted.
1127             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1128             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1129             # Files which will not be reformatted.
1130             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1131             self.invokeBlack([str(src2), "--diff", "--check"])
1132             # Multi file command.
1133             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1134
1135     def test_no_files(self) -> None:
1136         with cache_dir():
1137             # Without an argument, black exits with error code 0.
1138             self.invokeBlack([])
1139
1140     def test_broken_symlink(self) -> None:
1141         with cache_dir() as workspace:
1142             symlink = workspace / "broken_link.py"
1143             try:
1144                 symlink.symlink_to("nonexistent.py")
1145             except OSError as e:
1146                 self.skipTest(f"Can't create symlinks: {e}")
1147             self.invokeBlack([str(workspace.resolve())])
1148
1149     def test_read_cache_line_lengths(self) -> None:
1150         mode = black.FileMode()
1151         short_mode = black.FileMode(line_length=1)
1152         with cache_dir() as workspace:
1153             path = (workspace / "file.py").resolve()
1154             path.touch()
1155             black.write_cache({}, [path], mode)
1156             one = black.read_cache(mode)
1157             self.assertIn(path, one)
1158             two = black.read_cache(short_mode)
1159             self.assertNotIn(path, two)
1160
1161     def test_single_file_force_pyi(self) -> None:
1162         reg_mode = black.FileMode()
1163         pyi_mode = black.FileMode(is_pyi=True)
1164         contents, expected = read_data("force_pyi")
1165         with cache_dir() as workspace:
1166             path = (workspace / "file.py").resolve()
1167             with open(path, "w") as fh:
1168                 fh.write(contents)
1169             self.invokeBlack([str(path), "--pyi"])
1170             with open(path, "r") as fh:
1171                 actual = fh.read()
1172             # verify cache with --pyi is separate
1173             pyi_cache = black.read_cache(pyi_mode)
1174             self.assertIn(path, pyi_cache)
1175             normal_cache = black.read_cache(reg_mode)
1176             self.assertNotIn(path, normal_cache)
1177         self.assertEqual(actual, expected)
1178
1179     @event_loop(close=False)
1180     def test_multi_file_force_pyi(self) -> None:
1181         reg_mode = black.FileMode()
1182         pyi_mode = black.FileMode(is_pyi=True)
1183         contents, expected = read_data("force_pyi")
1184         with cache_dir() as workspace:
1185             paths = [
1186                 (workspace / "file1.py").resolve(),
1187                 (workspace / "file2.py").resolve(),
1188             ]
1189             for path in paths:
1190                 with open(path, "w") as fh:
1191                     fh.write(contents)
1192             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1193             for path in paths:
1194                 with open(path, "r") as fh:
1195                     actual = fh.read()
1196                 self.assertEqual(actual, expected)
1197             # verify cache with --pyi is separate
1198             pyi_cache = black.read_cache(pyi_mode)
1199             normal_cache = black.read_cache(reg_mode)
1200             for path in paths:
1201                 self.assertIn(path, pyi_cache)
1202                 self.assertNotIn(path, normal_cache)
1203
1204     def test_pipe_force_pyi(self) -> None:
1205         source, expected = read_data("force_pyi")
1206         result = CliRunner().invoke(
1207             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1208         )
1209         self.assertEqual(result.exit_code, 0)
1210         actual = result.output
1211         self.assertFormatEqual(actual, expected)
1212
1213     def test_single_file_force_py36(self) -> None:
1214         reg_mode = black.FileMode()
1215         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1216         source, expected = read_data("force_py36")
1217         with cache_dir() as workspace:
1218             path = (workspace / "file.py").resolve()
1219             with open(path, "w") as fh:
1220                 fh.write(source)
1221             self.invokeBlack([str(path), *PY36_ARGS])
1222             with open(path, "r") as fh:
1223                 actual = fh.read()
1224             # verify cache with --target-version is separate
1225             py36_cache = black.read_cache(py36_mode)
1226             self.assertIn(path, py36_cache)
1227             normal_cache = black.read_cache(reg_mode)
1228             self.assertNotIn(path, normal_cache)
1229         self.assertEqual(actual, expected)
1230
1231     @event_loop(close=False)
1232     def test_multi_file_force_py36(self) -> None:
1233         reg_mode = black.FileMode()
1234         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1235         source, expected = read_data("force_py36")
1236         with cache_dir() as workspace:
1237             paths = [
1238                 (workspace / "file1.py").resolve(),
1239                 (workspace / "file2.py").resolve(),
1240             ]
1241             for path in paths:
1242                 with open(path, "w") as fh:
1243                     fh.write(source)
1244             self.invokeBlack([str(p) for p in paths] + PY36_ARGS)
1245             for path in paths:
1246                 with open(path, "r") as fh:
1247                     actual = fh.read()
1248                 self.assertEqual(actual, expected)
1249             # verify cache with --target-version is separate
1250             pyi_cache = black.read_cache(py36_mode)
1251             normal_cache = black.read_cache(reg_mode)
1252             for path in paths:
1253                 self.assertIn(path, pyi_cache)
1254                 self.assertNotIn(path, normal_cache)
1255
1256     def test_pipe_force_py36(self) -> None:
1257         source, expected = read_data("force_py36")
1258         result = CliRunner().invoke(
1259             black.main,
1260             ["-", "-q", "--target-version=py36"],
1261             input=BytesIO(source.encode("utf8")),
1262         )
1263         self.assertEqual(result.exit_code, 0)
1264         actual = result.output
1265         self.assertFormatEqual(actual, expected)
1266
1267     def test_include_exclude(self) -> None:
1268         path = THIS_DIR / "data" / "include_exclude_tests"
1269         include = re.compile(r"\.pyi?$")
1270         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1271         report = black.Report()
1272         sources: List[Path] = []
1273         expected = [
1274             Path(path / "b/dont_exclude/a.py"),
1275             Path(path / "b/dont_exclude/a.pyi"),
1276         ]
1277         this_abs = THIS_DIR.resolve()
1278         sources.extend(
1279             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1280         )
1281         self.assertEqual(sorted(expected), sorted(sources))
1282
1283     def test_empty_include(self) -> None:
1284         path = THIS_DIR / "data" / "include_exclude_tests"
1285         report = black.Report()
1286         empty = re.compile(r"")
1287         sources: List[Path] = []
1288         expected = [
1289             Path(path / "b/exclude/a.pie"),
1290             Path(path / "b/exclude/a.py"),
1291             Path(path / "b/exclude/a.pyi"),
1292             Path(path / "b/dont_exclude/a.pie"),
1293             Path(path / "b/dont_exclude/a.py"),
1294             Path(path / "b/dont_exclude/a.pyi"),
1295             Path(path / "b/.definitely_exclude/a.pie"),
1296             Path(path / "b/.definitely_exclude/a.py"),
1297             Path(path / "b/.definitely_exclude/a.pyi"),
1298         ]
1299         this_abs = THIS_DIR.resolve()
1300         sources.extend(
1301             black.gen_python_files_in_dir(
1302                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1303             )
1304         )
1305         self.assertEqual(sorted(expected), sorted(sources))
1306
1307     def test_empty_exclude(self) -> None:
1308         path = THIS_DIR / "data" / "include_exclude_tests"
1309         report = black.Report()
1310         empty = re.compile(r"")
1311         sources: List[Path] = []
1312         expected = [
1313             Path(path / "b/dont_exclude/a.py"),
1314             Path(path / "b/dont_exclude/a.pyi"),
1315             Path(path / "b/exclude/a.py"),
1316             Path(path / "b/exclude/a.pyi"),
1317             Path(path / "b/.definitely_exclude/a.py"),
1318             Path(path / "b/.definitely_exclude/a.pyi"),
1319         ]
1320         this_abs = THIS_DIR.resolve()
1321         sources.extend(
1322             black.gen_python_files_in_dir(
1323                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1324             )
1325         )
1326         self.assertEqual(sorted(expected), sorted(sources))
1327
1328     def test_invalid_include_exclude(self) -> None:
1329         for option in ["--include", "--exclude"]:
1330             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1331
1332     def test_preserves_line_endings(self) -> None:
1333         with TemporaryDirectory() as workspace:
1334             test_file = Path(workspace) / "test.py"
1335             for nl in ["\n", "\r\n"]:
1336                 contents = nl.join(["def f(  ):", "    pass"])
1337                 test_file.write_bytes(contents.encode())
1338                 ff(test_file, write_back=black.WriteBack.YES)
1339                 updated_contents: bytes = test_file.read_bytes()
1340                 self.assertIn(nl.encode(), updated_contents)
1341                 if nl == "\n":
1342                     self.assertNotIn(b"\r\n", updated_contents)
1343
1344     def test_preserves_line_endings_via_stdin(self) -> None:
1345         for nl in ["\n", "\r\n"]:
1346             contents = nl.join(["def f(  ):", "    pass"])
1347             runner = BlackRunner()
1348             result = runner.invoke(
1349                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1350             )
1351             self.assertEqual(result.exit_code, 0)
1352             output = runner.stdout_bytes
1353             self.assertIn(nl.encode("utf8"), output)
1354             if nl == "\n":
1355                 self.assertNotIn(b"\r\n", output)
1356
1357     def test_assert_equivalent_different_asts(self) -> None:
1358         with self.assertRaises(AssertionError):
1359             black.assert_equivalent("{}", "None")
1360
1361     def test_symlink_out_of_root_directory(self) -> None:
1362         path = MagicMock()
1363         root = THIS_DIR
1364         child = MagicMock()
1365         include = re.compile(black.DEFAULT_INCLUDES)
1366         exclude = re.compile(black.DEFAULT_EXCLUDES)
1367         report = black.Report()
1368         # `child` should behave like a symlink which resolved path is clearly
1369         # outside of the `root` directory.
1370         path.iterdir.return_value = [child]
1371         child.resolve.return_value = Path("/a/b/c")
1372         child.is_symlink.return_value = True
1373         try:
1374             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1375         except ValueError as ve:
1376             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1377         path.iterdir.assert_called_once()
1378         child.resolve.assert_called_once()
1379         child.is_symlink.assert_called_once()
1380         # `child` should behave like a strange file which resolved path is clearly
1381         # outside of the `root` directory.
1382         child.is_symlink.return_value = False
1383         with self.assertRaises(ValueError):
1384             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1385         path.iterdir.assert_called()
1386         self.assertEqual(path.iterdir.call_count, 2)
1387         child.resolve.assert_called()
1388         self.assertEqual(child.resolve.call_count, 2)
1389         child.is_symlink.assert_called()
1390         self.assertEqual(child.is_symlink.call_count, 2)
1391
1392     def test_shhh_click(self) -> None:
1393         try:
1394             from click import _unicodefun  # type: ignore
1395         except ModuleNotFoundError:
1396             self.skipTest("Incompatible Click version")
1397         if not hasattr(_unicodefun, "_verify_python3_env"):
1398             self.skipTest("Incompatible Click version")
1399         # First, let's see if Click is crashing with a preferred ASCII charset.
1400         with patch("locale.getpreferredencoding") as gpe:
1401             gpe.return_value = "ASCII"
1402             with self.assertRaises(RuntimeError):
1403                 _unicodefun._verify_python3_env()
1404         # Now, let's silence Click...
1405         black.patch_click()
1406         # ...and confirm it's silent.
1407         with patch("locale.getpreferredencoding") as gpe:
1408             gpe.return_value = "ASCII"
1409             try:
1410                 _unicodefun._verify_python3_env()
1411             except RuntimeError as re:
1412                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1413
1414     def test_root_logger_not_used_directly(self) -> None:
1415         def fail(*args: Any, **kwargs: Any) -> None:
1416             self.fail("Record created with root logger")
1417
1418         with patch.multiple(
1419             logging.root,
1420             debug=fail,
1421             info=fail,
1422             warning=fail,
1423             error=fail,
1424             critical=fail,
1425             log=fail,
1426         ):
1427             ff(THIS_FILE)
1428
1429     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1430     @async_test
1431     async def test_blackd_request_needs_formatting(self) -> None:
1432         app = blackd.make_app()
1433         async with TestClient(TestServer(app)) as client:
1434             response = await client.post("/", data=b"print('hello world')")
1435             self.assertEqual(response.status, 200)
1436             self.assertEqual(response.charset, "utf8")
1437             self.assertEqual(await response.read(), b'print("hello world")\n')
1438
1439     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1440     @async_test
1441     async def test_blackd_request_no_change(self) -> None:
1442         app = blackd.make_app()
1443         async with TestClient(TestServer(app)) as client:
1444             response = await client.post("/", data=b'print("hello world")\n')
1445             self.assertEqual(response.status, 204)
1446             self.assertEqual(await response.read(), b"")
1447
1448     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1449     @async_test
1450     async def test_blackd_request_syntax_error(self) -> None:
1451         app = blackd.make_app()
1452         async with TestClient(TestServer(app)) as client:
1453             response = await client.post("/", data=b"what even ( is")
1454             self.assertEqual(response.status, 400)
1455             content = await response.text()
1456             self.assertTrue(
1457                 content.startswith("Cannot parse"),
1458                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1459             )
1460
1461     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1462     @async_test
1463     async def test_blackd_unsupported_version(self) -> None:
1464         app = blackd.make_app()
1465         async with TestClient(TestServer(app)) as client:
1466             response = await client.post(
1467                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1468             )
1469             self.assertEqual(response.status, 501)
1470
1471     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1472     @async_test
1473     async def test_blackd_supported_version(self) -> None:
1474         app = blackd.make_app()
1475         async with TestClient(TestServer(app)) as client:
1476             response = await client.post(
1477                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1478             )
1479             self.assertEqual(response.status, 200)
1480
1481     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1482     @async_test
1483     async def test_blackd_invalid_python_variant(self) -> None:
1484         app = blackd.make_app()
1485         async with TestClient(TestServer(app)) as client:
1486
1487             async def check(header_value: str, expected_status: int = 400) -> None:
1488                 response = await client.post(
1489                     "/",
1490                     data=b"what",
1491                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1492                 )
1493                 self.assertEqual(response.status, expected_status)
1494
1495             await check("lol")
1496             await check("ruby3.5")
1497             await check("pyi3.6")
1498             await check("py1.5")
1499             await check("2.8")
1500             await check("py2.8")
1501             await check("3.0")
1502             await check("pypy3.0")
1503             await check("jython3.4")
1504
1505     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1506     @async_test
1507     async def test_blackd_pyi(self) -> None:
1508         app = blackd.make_app()
1509         async with TestClient(TestServer(app)) as client:
1510             source, expected = read_data("stub.pyi")
1511             response = await client.post(
1512                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1513             )
1514             self.assertEqual(response.status, 200)
1515             self.assertEqual(await response.text(), expected)
1516
1517     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1518     @async_test
1519     async def test_blackd_python_variant(self) -> None:
1520         app = blackd.make_app()
1521         code = (
1522             "def f(\n"
1523             "    and_has_a_bunch_of,\n"
1524             "    very_long_arguments_too,\n"
1525             "    and_lots_of_them_as_well_lol,\n"
1526             "    **and_very_long_keyword_arguments\n"
1527             "):\n"
1528             "    pass\n"
1529         )
1530         async with TestClient(TestServer(app)) as client:
1531
1532             async def check(header_value: str, expected_status: int) -> None:
1533                 response = await client.post(
1534                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1535                 )
1536                 self.assertEqual(response.status, expected_status)
1537
1538             await check("3.6", 200)
1539             await check("py3.6", 200)
1540             await check("3.6,3.7", 200)
1541             await check("3.6,py3.7", 200)
1542
1543             await check("2", 204)
1544             await check("2.7", 204)
1545             await check("py2.7", 204)
1546             await check("3.4", 204)
1547             await check("py3.4", 204)
1548
1549     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1550     @async_test
1551     async def test_blackd_fast(self) -> None:
1552         with open(os.devnull, "w") as dn, redirect_stderr(dn):
1553             app = blackd.make_app()
1554             async with TestClient(TestServer(app)) as client:
1555                 response = await client.post("/", data=b"ur'hello'")
1556                 self.assertEqual(response.status, 500)
1557                 self.assertIn("failed to parse source file", await response.text())
1558                 response = await client.post(
1559                     "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1560                 )
1561                 self.assertEqual(response.status, 200)
1562
1563     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1564     @async_test
1565     async def test_blackd_line_length(self) -> None:
1566         app = blackd.make_app()
1567         async with TestClient(TestServer(app)) as client:
1568             response = await client.post(
1569                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1570             )
1571             self.assertEqual(response.status, 200)
1572
1573     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1574     @async_test
1575     async def test_blackd_invalid_line_length(self) -> None:
1576         app = blackd.make_app()
1577         async with TestClient(TestServer(app)) as client:
1578             response = await client.post(
1579                 "/",
1580                 data=b'print("hello")\n',
1581                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1582             )
1583             self.assertEqual(response.status, 400)
1584
1585     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1586     def test_blackd_main(self) -> None:
1587         with patch("blackd.web.run_app"):
1588             result = CliRunner().invoke(blackd.main, [])
1589             if result.exception is not None:
1590                 raise result.exception
1591             self.assertEqual(result.exit_code, 0)
1592
1593
if __name__ == "__main__":
    # Run the test suite when this file is executed directly; the module to
    # load tests from is named explicitly as "test_black".
    unittest.main(module="test_black")