git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add `--target-version` option to allow users to choose targeted Python versions ...
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager, redirect_stderr
5 from functools import partial, wraps
6 from io import BytesIO, TextIOWrapper
7 import os
8 from pathlib import Path
9 import re
10 import sys
11 from tempfile import TemporaryDirectory
12 from typing import (
13     Any,
14     BinaryIO,
15     Callable,
16     Coroutine,
17     Generator,
18     List,
19     Tuple,
20     Iterator,
21     TypeVar,
22 )
23 import unittest
24 from unittest.mock import patch, MagicMock
25
26 from click import unstyle
27 from click.testing import CliRunner
28
29 import black
30 from black import Feature
31
32 try:
33     import blackd
34     from aiohttp.test_utils import TestClient, TestServer
35 except ImportError:
36     has_blackd_deps = False
37 else:
38     has_blackd_deps = True
39
40
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# read_data() strips this marker from every line it reads, and test_self reads this
# very file, so the literal is split in two to keep this line from being altered.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
T = TypeVar("T")
R = TypeVar("R")
# Shorthands using the default FileMode: ``ff`` formats a file in place (with
# fast=True), ``fs`` formats a source string.
ff = partial(black.format_file_in_place, mode=black.FileMode(), fast=True)
fs = partial(black.format_str, mode=black.FileMode())
48
49
def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines, with one extra newline on each side.

    Used as a stand-in for ``black.dump_to_file`` (via ``@patch``), so the text is
    returned directly instead of being written to a temporary file.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
52
53
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Reads ``name`` from ``tests/data`` (or from the tests directory itself when
    *data* is false); ``.py`` is appended unless the name already ends with a known
    extension. Lines before a line reading ``# output`` form the input, lines after
    it the expected output. Without that marker, the input doubles as the expected
    output. Both parts are stripped and end with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0
    for raw_line in lines:
        cleaned = raw_line.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            current = 1
            continue
        sections[current].append(cleaned)
    input_lines, output_lines = sections
    if input_lines and not output_lines:
        # No output marker: the input is expected to be already formatted.
        output_lines = list(input_lines)
    return "".join(input_lines).strip() + "\n", "".join(output_lines).strip() + "\n"
75
76
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory for the ``with`` block.

    Yields the patched cache path. With ``exists=False`` the path is a ``new``
    subdirectory of the temporary workspace that has not been created.
    """
    with TemporaryDirectory() as workspace:
        base = Path(workspace)
        target = base if exists else base / "new"
        with patch("black.CACHE_DIR", target):
            yield target
85
86
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the ``with`` body with a brand-new event loop set as current.

    The loop that was current on entry is restored on exit. If there was none
    (``get_event_loop()`` raised ``RuntimeError``), ``None`` is restored instead.

    Args:
        close: when true, the freshly created loop is closed on exit.
    """
    from typing import Optional

    policy = asyncio.get_event_loop_policy()
    old_loop: Optional[asyncio.AbstractEventLoop]
    try:
        old_loop = policy.get_event_loop()
    except RuntimeError:
        # Raised when the current thread has no event loop (non-main threads, and
        # newer Python versions that no longer create one implicitly).
        old_loop = None
    loop = policy.new_event_loop()
    policy.set_event_loop(loop)
    try:
        yield
    finally:
        policy.set_event_loop(old_loop)
        if close:
            loop.close()
100
101
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Adapt the coroutine function *f* into a plain, synchronous callable.

    Every call of the returned function runs ``f(*args, **kwargs)`` to completion
    on a fresh event loop (installed, and closed afterwards, by ``event_loop``);
    the coroutine's return value is discarded.
    """

    @wraps(f)
    def run_to_completion(*args: Any, **kwargs: Any) -> None:
        loop = asyncio.get_event_loop()
        loop.run_until_complete(f(*args, **kwargs))

    return event_loop(close=True)(run_to_completion)
109
110
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    Text written to ``sys.stderr`` during ``invoke()`` is captured separately in
    ``stderr_bytes``; stdout is captured in ``stdout_bytes``.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            stderr_stream = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_stream
            try:
                yield output
            finally:
                # Restore first so stderr is never left hijacked if capture fails.
                sys.stderr = hold_stderr
                # TextIOWrapper buffers text internally; flush so that everything
                # written so far actually reaches the underlying byte buffers.
                sys.stdout.flush()
                stderr_stream.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
134
135
136 class BlackTestCase(unittest.TestCase):
137     maxDiff = None
138
139     def assertFormatEqual(self, expected: str, actual: str) -> None:
140         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
141             bdv: black.DebugVisitor[Any]
142             black.out("Expected tree:", fg="green")
143             try:
144                 exp_node = black.lib2to3_parse(expected)
145                 bdv = black.DebugVisitor()
146                 list(bdv.visit(exp_node))
147             except Exception as ve:
148                 black.err(str(ve))
149             black.out("Actual tree:", fg="red")
150             try:
151                 exp_node = black.lib2to3_parse(actual)
152                 bdv = black.DebugVisitor()
153                 list(bdv.visit(exp_node))
154             except Exception as ve:
155                 black.err(str(ve))
156         self.assertEqual(expected, actual)
157
158     def invokeBlack(
159         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
160     ) -> None:
161         runner = BlackRunner()
162         if ignore_config:
163             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
164         result = runner.invoke(black.main, args)
165         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
166
167     @patch("black.dump_to_file", dump_to_stderr)
168     def test_empty(self) -> None:
169         source = expected = ""
170         actual = fs(source)
171         self.assertFormatEqual(expected, actual)
172         black.assert_equivalent(source, actual)
173         black.assert_stable(source, actual, black.FileMode())
174
175     def test_empty_ff(self) -> None:
176         expected = ""
177         tmp_file = Path(black.dump_to_file())
178         try:
179             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
180             with open(tmp_file, encoding="utf8") as f:
181                 actual = f.read()
182         finally:
183             os.unlink(tmp_file)
184         self.assertFormatEqual(expected, actual)
185
186     @patch("black.dump_to_file", dump_to_stderr)
187     def test_self(self) -> None:
188         source, expected = read_data("test_black", data=False)
189         actual = fs(source)
190         self.assertFormatEqual(expected, actual)
191         black.assert_equivalent(source, actual)
192         black.assert_stable(source, actual, black.FileMode())
193         self.assertFalse(ff(THIS_FILE))
194
195     @patch("black.dump_to_file", dump_to_stderr)
196     def test_black(self) -> None:
197         source, expected = read_data("../black", data=False)
198         actual = fs(source)
199         self.assertFormatEqual(expected, actual)
200         black.assert_equivalent(source, actual)
201         black.assert_stable(source, actual, black.FileMode())
202         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
203
204     def test_piping(self) -> None:
205         source, expected = read_data("../black", data=False)
206         result = BlackRunner().invoke(
207             black.main,
208             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
209             input=BytesIO(source.encode("utf8")),
210         )
211         self.assertEqual(result.exit_code, 0)
212         self.assertFormatEqual(expected, result.output)
213         black.assert_equivalent(source, result.output)
214         black.assert_stable(source, result.output, black.FileMode())
215
216     def test_piping_diff(self) -> None:
217         diff_header = re.compile(
218             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
219             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
220         )
221         source, _ = read_data("expression.py")
222         expected, _ = read_data("expression.diff")
223         config = THIS_DIR / "data" / "empty_pyproject.toml"
224         args = [
225             "-",
226             "--fast",
227             f"--line-length={black.DEFAULT_LINE_LENGTH}",
228             "--diff",
229             f"--config={config}",
230         ]
231         result = BlackRunner().invoke(
232             black.main, args, input=BytesIO(source.encode("utf8"))
233         )
234         self.assertEqual(result.exit_code, 0)
235         actual = diff_header.sub("[Deterministic header]", result.output)
236         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
237         self.assertEqual(expected, actual)
238
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_setup(self) -> None:
241         source, expected = read_data("../setup", data=False)
242         actual = fs(source)
243         self.assertFormatEqual(expected, actual)
244         black.assert_equivalent(source, actual)
245         black.assert_stable(source, actual, black.FileMode())
246         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
247
248     @patch("black.dump_to_file", dump_to_stderr)
249     def test_function(self) -> None:
250         source, expected = read_data("function")
251         actual = fs(source)
252         self.assertFormatEqual(expected, actual)
253         black.assert_equivalent(source, actual)
254         black.assert_stable(source, actual, black.FileMode())
255
256     @patch("black.dump_to_file", dump_to_stderr)
257     def test_function2(self) -> None:
258         source, expected = read_data("function2")
259         actual = fs(source)
260         self.assertFormatEqual(expected, actual)
261         black.assert_equivalent(source, actual)
262         black.assert_stable(source, actual, black.FileMode())
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_expression(self) -> None:
266         source, expected = read_data("expression")
267         actual = fs(source)
268         self.assertFormatEqual(expected, actual)
269         black.assert_equivalent(source, actual)
270         black.assert_stable(source, actual, black.FileMode())
271
272     def test_expression_ff(self) -> None:
273         source, expected = read_data("expression")
274         tmp_file = Path(black.dump_to_file(source))
275         try:
276             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
277             with open(tmp_file, encoding="utf8") as f:
278                 actual = f.read()
279         finally:
280             os.unlink(tmp_file)
281         self.assertFormatEqual(expected, actual)
282         with patch("black.dump_to_file", dump_to_stderr):
283             black.assert_equivalent(source, actual)
284             black.assert_stable(source, actual, black.FileMode())
285
286     def test_expression_diff(self) -> None:
287         source, _ = read_data("expression.py")
288         expected, _ = read_data("expression.diff")
289         tmp_file = Path(black.dump_to_file(source))
290         diff_header = re.compile(
291             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
292             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
293         )
294         try:
295             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
296             self.assertEqual(result.exit_code, 0)
297         finally:
298             os.unlink(tmp_file)
299         actual = result.output
300         actual = diff_header.sub("[Deterministic header]", actual)
301         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
302         if expected != actual:
303             dump = black.dump_to_file(actual)
304             msg = (
305                 f"Expected diff isn't equal to the actual. If you made changes "
306                 f"to expression.py and this is an anticipated difference, "
307                 f"overwrite tests/data/expression.diff with {dump}"
308             )
309             self.assertEqual(expected, actual, msg)
310
311     @patch("black.dump_to_file", dump_to_stderr)
312     def test_fstring(self) -> None:
313         source, expected = read_data("fstring")
314         actual = fs(source)
315         self.assertFormatEqual(expected, actual)
316         black.assert_equivalent(source, actual)
317         black.assert_stable(source, actual, black.FileMode())
318
319     @patch("black.dump_to_file", dump_to_stderr)
320     def test_string_quotes(self) -> None:
321         source, expected = read_data("string_quotes")
322         actual = fs(source)
323         self.assertFormatEqual(expected, actual)
324         black.assert_equivalent(source, actual)
325         black.assert_stable(source, actual, black.FileMode())
326         mode = black.FileMode(string_normalization=False)
327         not_normalized = fs(source, mode=mode)
328         self.assertFormatEqual(source, not_normalized)
329         black.assert_equivalent(source, not_normalized)
330         black.assert_stable(source, not_normalized, mode=mode)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_slices(self) -> None:
334         source, expected = read_data("slices")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_equivalent(source, actual)
338         black.assert_stable(source, actual, black.FileMode())
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_comments(self) -> None:
342         source, expected = read_data("comments")
343         actual = fs(source)
344         self.assertFormatEqual(expected, actual)
345         black.assert_equivalent(source, actual)
346         black.assert_stable(source, actual, black.FileMode())
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def test_comments2(self) -> None:
350         source, expected = read_data("comments2")
351         actual = fs(source)
352         self.assertFormatEqual(expected, actual)
353         black.assert_equivalent(source, actual)
354         black.assert_stable(source, actual, black.FileMode())
355
356     @patch("black.dump_to_file", dump_to_stderr)
357     def test_comments3(self) -> None:
358         source, expected = read_data("comments3")
359         actual = fs(source)
360         self.assertFormatEqual(expected, actual)
361         black.assert_equivalent(source, actual)
362         black.assert_stable(source, actual, black.FileMode())
363
364     @patch("black.dump_to_file", dump_to_stderr)
365     def test_comments4(self) -> None:
366         source, expected = read_data("comments4")
367         actual = fs(source)
368         self.assertFormatEqual(expected, actual)
369         black.assert_equivalent(source, actual)
370         black.assert_stable(source, actual, black.FileMode())
371
372     @patch("black.dump_to_file", dump_to_stderr)
373     def test_comments5(self) -> None:
374         source, expected = read_data("comments5")
375         actual = fs(source)
376         self.assertFormatEqual(expected, actual)
377         black.assert_equivalent(source, actual)
378         black.assert_stable(source, actual, black.FileMode())
379
380     @patch("black.dump_to_file", dump_to_stderr)
381     def test_comments6(self) -> None:
382         source, expected = read_data("comments6")
383         actual = fs(source)
384         self.assertFormatEqual(expected, actual)
385         black.assert_equivalent(source, actual)
386         black.assert_stable(source, actual, black.FileMode())
387
388     @patch("black.dump_to_file", dump_to_stderr)
389     def test_cantfit(self) -> None:
390         source, expected = read_data("cantfit")
391         actual = fs(source)
392         self.assertFormatEqual(expected, actual)
393         black.assert_equivalent(source, actual)
394         black.assert_stable(source, actual, black.FileMode())
395
396     @patch("black.dump_to_file", dump_to_stderr)
397     def test_import_spacing(self) -> None:
398         source, expected = read_data("import_spacing")
399         actual = fs(source)
400         self.assertFormatEqual(expected, actual)
401         black.assert_equivalent(source, actual)
402         black.assert_stable(source, actual, black.FileMode())
403
404     @patch("black.dump_to_file", dump_to_stderr)
405     def test_composition(self) -> None:
406         source, expected = read_data("composition")
407         actual = fs(source)
408         self.assertFormatEqual(expected, actual)
409         black.assert_equivalent(source, actual)
410         black.assert_stable(source, actual, black.FileMode())
411
412     @patch("black.dump_to_file", dump_to_stderr)
413     def test_empty_lines(self) -> None:
414         source, expected = read_data("empty_lines")
415         actual = fs(source)
416         self.assertFormatEqual(expected, actual)
417         black.assert_equivalent(source, actual)
418         black.assert_stable(source, actual, black.FileMode())
419
420     @patch("black.dump_to_file", dump_to_stderr)
421     def test_string_prefixes(self) -> None:
422         source, expected = read_data("string_prefixes")
423         actual = fs(source)
424         self.assertFormatEqual(expected, actual)
425         black.assert_equivalent(source, actual)
426         black.assert_stable(source, actual, black.FileMode())
427
428     @patch("black.dump_to_file", dump_to_stderr)
429     def test_numeric_literals(self) -> None:
430         source, expected = read_data("numeric_literals")
431         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
432         actual = fs(source, mode=mode)
433         self.assertFormatEqual(expected, actual)
434         black.assert_equivalent(source, actual)
435         black.assert_stable(source, actual, mode)
436
437     @patch("black.dump_to_file", dump_to_stderr)
438     def test_numeric_literals_ignoring_underscores(self) -> None:
439         source, expected = read_data("numeric_literals_skip_underscores")
440         mode = black.FileMode(
441             numeric_underscore_normalization=False, target_versions=black.PY36_VERSIONS
442         )
443         actual = fs(source, mode=mode)
444         self.assertFormatEqual(expected, actual)
445         black.assert_equivalent(source, actual)
446         black.assert_stable(source, actual, mode)
447
448     @patch("black.dump_to_file", dump_to_stderr)
449     def test_numeric_literals_py2(self) -> None:
450         source, expected = read_data("numeric_literals_py2")
451         actual = fs(source)
452         self.assertFormatEqual(expected, actual)
453         black.assert_stable(source, actual, black.FileMode())
454
455     @patch("black.dump_to_file", dump_to_stderr)
456     def test_python2(self) -> None:
457         source, expected = read_data("python2")
458         actual = fs(source)
459         self.assertFormatEqual(expected, actual)
460         # black.assert_equivalent(source, actual)
461         black.assert_stable(source, actual, black.FileMode())
462
463     @patch("black.dump_to_file", dump_to_stderr)
464     def test_python2_unicode_literals(self) -> None:
465         source, expected = read_data("python2_unicode_literals")
466         actual = fs(source)
467         self.assertFormatEqual(expected, actual)
468         black.assert_stable(source, actual, black.FileMode())
469
470     @patch("black.dump_to_file", dump_to_stderr)
471     def test_stub(self) -> None:
472         mode = black.FileMode(is_pyi=True)
473         source, expected = read_data("stub.pyi")
474         actual = fs(source, mode=mode)
475         self.assertFormatEqual(expected, actual)
476         black.assert_stable(source, actual, mode)
477
478     @patch("black.dump_to_file", dump_to_stderr)
479     def test_python37(self) -> None:
480         source, expected = read_data("python37")
481         actual = fs(source)
482         self.assertFormatEqual(expected, actual)
483         major, minor = sys.version_info[:2]
484         if major > 3 or (major == 3 and minor >= 7):
485             black.assert_equivalent(source, actual)
486         black.assert_stable(source, actual, black.FileMode())
487
488     @patch("black.dump_to_file", dump_to_stderr)
489     def test_fmtonoff(self) -> None:
490         source, expected = read_data("fmtonoff")
491         actual = fs(source)
492         self.assertFormatEqual(expected, actual)
493         black.assert_equivalent(source, actual)
494         black.assert_stable(source, actual, black.FileMode())
495
496     @patch("black.dump_to_file", dump_to_stderr)
497     def test_fmtonoff2(self) -> None:
498         source, expected = read_data("fmtonoff2")
499         actual = fs(source)
500         self.assertFormatEqual(expected, actual)
501         black.assert_equivalent(source, actual)
502         black.assert_stable(source, actual, black.FileMode())
503
504     @patch("black.dump_to_file", dump_to_stderr)
505     def test_remove_empty_parentheses_after_class(self) -> None:
506         source, expected = read_data("class_blank_parentheses")
507         actual = fs(source)
508         self.assertFormatEqual(expected, actual)
509         black.assert_equivalent(source, actual)
510         black.assert_stable(source, actual, black.FileMode())
511
512     @patch("black.dump_to_file", dump_to_stderr)
513     def test_new_line_between_class_and_code(self) -> None:
514         source, expected = read_data("class_methods_new_line")
515         actual = fs(source)
516         self.assertFormatEqual(expected, actual)
517         black.assert_equivalent(source, actual)
518         black.assert_stable(source, actual, black.FileMode())
519
520     @patch("black.dump_to_file", dump_to_stderr)
521     def test_bracket_match(self) -> None:
522         source, expected = read_data("bracketmatch")
523         actual = fs(source)
524         self.assertFormatEqual(expected, actual)
525         black.assert_equivalent(source, actual)
526         black.assert_stable(source, actual, black.FileMode())
527
528     def test_comment_indentation(self) -> None:
529         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
530         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
531
532         self.assertFormatEqual(fs(contents_spc), contents_spc)
533         self.assertFormatEqual(fs(contents_tab), contents_spc)
534
535         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
536         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
537
538         self.assertFormatEqual(fs(contents_tab), contents_spc)
539         self.assertFormatEqual(fs(contents_spc), contents_spc)
540
541     def test_report_verbose(self) -> None:
542         report = black.Report(verbose=True)
543         out_lines = []
544         err_lines = []
545
546         def out(msg: str, **kwargs: Any) -> None:
547             out_lines.append(msg)
548
549         def err(msg: str, **kwargs: Any) -> None:
550             err_lines.append(msg)
551
552         with patch("black.out", out), patch("black.err", err):
553             report.done(Path("f1"), black.Changed.NO)
554             self.assertEqual(len(out_lines), 1)
555             self.assertEqual(len(err_lines), 0)
556             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
557             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
558             self.assertEqual(report.return_code, 0)
559             report.done(Path("f2"), black.Changed.YES)
560             self.assertEqual(len(out_lines), 2)
561             self.assertEqual(len(err_lines), 0)
562             self.assertEqual(out_lines[-1], "reformatted f2")
563             self.assertEqual(
564                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
565             )
566             report.done(Path("f3"), black.Changed.CACHED)
567             self.assertEqual(len(out_lines), 3)
568             self.assertEqual(len(err_lines), 0)
569             self.assertEqual(
570                 out_lines[-1], "f3 wasn't modified on disk since last run."
571             )
572             self.assertEqual(
573                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
574             )
575             self.assertEqual(report.return_code, 0)
576             report.check = True
577             self.assertEqual(report.return_code, 1)
578             report.check = False
579             report.failed(Path("e1"), "boom")
580             self.assertEqual(len(out_lines), 3)
581             self.assertEqual(len(err_lines), 1)
582             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
583             self.assertEqual(
584                 unstyle(str(report)),
585                 "1 file reformatted, 2 files left unchanged, "
586                 "1 file failed to reformat.",
587             )
588             self.assertEqual(report.return_code, 123)
589             report.done(Path("f3"), black.Changed.YES)
590             self.assertEqual(len(out_lines), 4)
591             self.assertEqual(len(err_lines), 1)
592             self.assertEqual(out_lines[-1], "reformatted f3")
593             self.assertEqual(
594                 unstyle(str(report)),
595                 "2 files reformatted, 2 files left unchanged, "
596                 "1 file failed to reformat.",
597             )
598             self.assertEqual(report.return_code, 123)
599             report.failed(Path("e2"), "boom")
600             self.assertEqual(len(out_lines), 4)
601             self.assertEqual(len(err_lines), 2)
602             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
603             self.assertEqual(
604                 unstyle(str(report)),
605                 "2 files reformatted, 2 files left unchanged, "
606                 "2 files failed to reformat.",
607             )
608             self.assertEqual(report.return_code, 123)
609             report.path_ignored(Path("wat"), "no match")
610             self.assertEqual(len(out_lines), 5)
611             self.assertEqual(len(err_lines), 2)
612             self.assertEqual(out_lines[-1], "wat ignored: no match")
613             self.assertEqual(
614                 unstyle(str(report)),
615                 "2 files reformatted, 2 files left unchanged, "
616                 "2 files failed to reformat.",
617             )
618             self.assertEqual(report.return_code, 123)
619             report.done(Path("f4"), black.Changed.NO)
620             self.assertEqual(len(out_lines), 6)
621             self.assertEqual(len(err_lines), 2)
622             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
623             self.assertEqual(
624                 unstyle(str(report)),
625                 "2 files reformatted, 3 files left unchanged, "
626                 "2 files failed to reformat.",
627             )
628             self.assertEqual(report.return_code, 123)
629             report.check = True
630             self.assertEqual(
631                 unstyle(str(report)),
632                 "2 files would be reformatted, 3 files would be left unchanged, "
633                 "2 files would fail to reformat.",
634             )
635
636     def test_report_quiet(self) -> None:
637         report = black.Report(quiet=True)
638         out_lines = []
639         err_lines = []
640
641         def out(msg: str, **kwargs: Any) -> None:
642             out_lines.append(msg)
643
644         def err(msg: str, **kwargs: Any) -> None:
645             err_lines.append(msg)
646
647         with patch("black.out", out), patch("black.err", err):
648             report.done(Path("f1"), black.Changed.NO)
649             self.assertEqual(len(out_lines), 0)
650             self.assertEqual(len(err_lines), 0)
651             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
652             self.assertEqual(report.return_code, 0)
653             report.done(Path("f2"), black.Changed.YES)
654             self.assertEqual(len(out_lines), 0)
655             self.assertEqual(len(err_lines), 0)
656             self.assertEqual(
657                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
658             )
659             report.done(Path("f3"), black.Changed.CACHED)
660             self.assertEqual(len(out_lines), 0)
661             self.assertEqual(len(err_lines), 0)
662             self.assertEqual(
663                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
664             )
665             self.assertEqual(report.return_code, 0)
666             report.check = True
667             self.assertEqual(report.return_code, 1)
668             report.check = False
669             report.failed(Path("e1"), "boom")
670             self.assertEqual(len(out_lines), 0)
671             self.assertEqual(len(err_lines), 1)
672             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
673             self.assertEqual(
674                 unstyle(str(report)),
675                 "1 file reformatted, 2 files left unchanged, "
676                 "1 file failed to reformat.",
677             )
678             self.assertEqual(report.return_code, 123)
679             report.done(Path("f3"), black.Changed.YES)
680             self.assertEqual(len(out_lines), 0)
681             self.assertEqual(len(err_lines), 1)
682             self.assertEqual(
683                 unstyle(str(report)),
684                 "2 files reformatted, 2 files left unchanged, "
685                 "1 file failed to reformat.",
686             )
687             self.assertEqual(report.return_code, 123)
688             report.failed(Path("e2"), "boom")
689             self.assertEqual(len(out_lines), 0)
690             self.assertEqual(len(err_lines), 2)
691             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
692             self.assertEqual(
693                 unstyle(str(report)),
694                 "2 files reformatted, 2 files left unchanged, "
695                 "2 files failed to reformat.",
696             )
697             self.assertEqual(report.return_code, 123)
698             report.path_ignored(Path("wat"), "no match")
699             self.assertEqual(len(out_lines), 0)
700             self.assertEqual(len(err_lines), 2)
701             self.assertEqual(
702                 unstyle(str(report)),
703                 "2 files reformatted, 2 files left unchanged, "
704                 "2 files failed to reformat.",
705             )
706             self.assertEqual(report.return_code, 123)
707             report.done(Path("f4"), black.Changed.NO)
708             self.assertEqual(len(out_lines), 0)
709             self.assertEqual(len(err_lines), 2)
710             self.assertEqual(
711                 unstyle(str(report)),
712                 "2 files reformatted, 3 files left unchanged, "
713                 "2 files failed to reformat.",
714             )
715             self.assertEqual(report.return_code, 123)
716             report.check = True
717             self.assertEqual(
718                 unstyle(str(report)),
719                 "2 files would be reformatted, 3 files would be left unchanged, "
720                 "2 files would fail to reformat.",
721             )
722
723     def test_report_normal(self) -> None:
724         report = black.Report()
725         out_lines = []
726         err_lines = []
727
728         def out(msg: str, **kwargs: Any) -> None:
729             out_lines.append(msg)
730
731         def err(msg: str, **kwargs: Any) -> None:
732             err_lines.append(msg)
733
734         with patch("black.out", out), patch("black.err", err):
735             report.done(Path("f1"), black.Changed.NO)
736             self.assertEqual(len(out_lines), 0)
737             self.assertEqual(len(err_lines), 0)
738             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
739             self.assertEqual(report.return_code, 0)
740             report.done(Path("f2"), black.Changed.YES)
741             self.assertEqual(len(out_lines), 1)
742             self.assertEqual(len(err_lines), 0)
743             self.assertEqual(out_lines[-1], "reformatted f2")
744             self.assertEqual(
745                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
746             )
747             report.done(Path("f3"), black.Changed.CACHED)
748             self.assertEqual(len(out_lines), 1)
749             self.assertEqual(len(err_lines), 0)
750             self.assertEqual(out_lines[-1], "reformatted f2")
751             self.assertEqual(
752                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
753             )
754             self.assertEqual(report.return_code, 0)
755             report.check = True
756             self.assertEqual(report.return_code, 1)
757             report.check = False
758             report.failed(Path("e1"), "boom")
759             self.assertEqual(len(out_lines), 1)
760             self.assertEqual(len(err_lines), 1)
761             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
762             self.assertEqual(
763                 unstyle(str(report)),
764                 "1 file reformatted, 2 files left unchanged, "
765                 "1 file failed to reformat.",
766             )
767             self.assertEqual(report.return_code, 123)
768             report.done(Path("f3"), black.Changed.YES)
769             self.assertEqual(len(out_lines), 2)
770             self.assertEqual(len(err_lines), 1)
771             self.assertEqual(out_lines[-1], "reformatted f3")
772             self.assertEqual(
773                 unstyle(str(report)),
774                 "2 files reformatted, 2 files left unchanged, "
775                 "1 file failed to reformat.",
776             )
777             self.assertEqual(report.return_code, 123)
778             report.failed(Path("e2"), "boom")
779             self.assertEqual(len(out_lines), 2)
780             self.assertEqual(len(err_lines), 2)
781             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
782             self.assertEqual(
783                 unstyle(str(report)),
784                 "2 files reformatted, 2 files left unchanged, "
785                 "2 files failed to reformat.",
786             )
787             self.assertEqual(report.return_code, 123)
788             report.path_ignored(Path("wat"), "no match")
789             self.assertEqual(len(out_lines), 2)
790             self.assertEqual(len(err_lines), 2)
791             self.assertEqual(
792                 unstyle(str(report)),
793                 "2 files reformatted, 2 files left unchanged, "
794                 "2 files failed to reformat.",
795             )
796             self.assertEqual(report.return_code, 123)
797             report.done(Path("f4"), black.Changed.NO)
798             self.assertEqual(len(out_lines), 2)
799             self.assertEqual(len(err_lines), 2)
800             self.assertEqual(
801                 unstyle(str(report)),
802                 "2 files reformatted, 3 files left unchanged, "
803                 "2 files failed to reformat.",
804             )
805             self.assertEqual(report.return_code, 123)
806             report.check = True
807             self.assertEqual(
808                 unstyle(str(report)),
809                 "2 files would be reformatted, 3 files would be left unchanged, "
810                 "2 files would fail to reformat.",
811             )
812
813     def test_get_features_used(self) -> None:
814         node = black.lib2to3_parse("def f(*, arg): ...\n")
815         self.assertEqual(black.get_features_used(node), set())
816         node = black.lib2to3_parse("def f(*, arg,): ...\n")
817         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
818         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
819         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
820         node = black.lib2to3_parse("123_456\n")
821         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
822         node = black.lib2to3_parse("123456\n")
823         self.assertEqual(black.get_features_used(node), set())
824         source, expected = read_data("function")
825         node = black.lib2to3_parse(source)
826         self.assertEqual(
827             black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
828         )
829         node = black.lib2to3_parse(expected)
830         self.assertEqual(
831             black.get_features_used(node),
832             {Feature.TRAILING_COMMA, Feature.F_STRINGS, Feature.NUMERIC_UNDERSCORES},
833         )
834         source, expected = read_data("expression")
835         node = black.lib2to3_parse(source)
836         self.assertEqual(black.get_features_used(node), set())
837         node = black.lib2to3_parse(expected)
838         self.assertEqual(black.get_features_used(node), set())
839
840     def test_get_future_imports(self) -> None:
841         node = black.lib2to3_parse("\n")
842         self.assertEqual(set(), black.get_future_imports(node))
843         node = black.lib2to3_parse("from __future__ import black\n")
844         self.assertEqual({"black"}, black.get_future_imports(node))
845         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
846         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
847         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
848         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
849         node = black.lib2to3_parse(
850             "from __future__ import multiple\nfrom __future__ import imports\n"
851         )
852         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
853         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
854         self.assertEqual({"black"}, black.get_future_imports(node))
855         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
856         self.assertEqual({"black"}, black.get_future_imports(node))
857         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
858         self.assertEqual(set(), black.get_future_imports(node))
859         node = black.lib2to3_parse("from some.module import black\n")
860         self.assertEqual(set(), black.get_future_imports(node))
861         node = black.lib2to3_parse(
862             "from __future__ import unicode_literals as _unicode_literals"
863         )
864         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
865         node = black.lib2to3_parse(
866             "from __future__ import unicode_literals as _lol, print"
867         )
868         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
869
870     def test_debug_visitor(self) -> None:
871         source, _ = read_data("debug_visitor.py")
872         expected, _ = read_data("debug_visitor.out")
873         out_lines = []
874         err_lines = []
875
876         def out(msg: str, **kwargs: Any) -> None:
877             out_lines.append(msg)
878
879         def err(msg: str, **kwargs: Any) -> None:
880             err_lines.append(msg)
881
882         with patch("black.out", out), patch("black.err", err):
883             black.DebugVisitor.show(source)
884         actual = "\n".join(out_lines) + "\n"
885         log_name = ""
886         if expected != actual:
887             log_name = black.dump_to_file(*out_lines)
888         self.assertEqual(
889             expected,
890             actual,
891             f"AST print out is different. Actual version dumped to {log_name}",
892         )
893
894     def test_format_file_contents(self) -> None:
895         empty = ""
896         mode = black.FileMode()
897         with self.assertRaises(black.NothingChanged):
898             black.format_file_contents(empty, mode=mode, fast=False)
899         just_nl = "\n"
900         with self.assertRaises(black.NothingChanged):
901             black.format_file_contents(just_nl, mode=mode, fast=False)
902         same = "l = [1, 2, 3]\n"
903         with self.assertRaises(black.NothingChanged):
904             black.format_file_contents(same, mode=mode, fast=False)
905         different = "l = [1,2,3]"
906         expected = same
907         actual = black.format_file_contents(different, mode=mode, fast=False)
908         self.assertEqual(expected, actual)
909         invalid = "return if you can"
910         with self.assertRaises(black.InvalidInput) as e:
911             black.format_file_contents(invalid, mode=mode, fast=False)
912         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
913
914     def test_endmarker(self) -> None:
915         n = black.lib2to3_parse("\n")
916         self.assertEqual(n.type, black.syms.file_input)
917         self.assertEqual(len(n.children), 1)
918         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
919
920     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
921     def test_assertFormatEqual(self) -> None:
922         out_lines = []
923         err_lines = []
924
925         def out(msg: str, **kwargs: Any) -> None:
926             out_lines.append(msg)
927
928         def err(msg: str, **kwargs: Any) -> None:
929             err_lines.append(msg)
930
931         with patch("black.out", out), patch("black.err", err):
932             with self.assertRaises(AssertionError):
933                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
934
935         out_str = "".join(out_lines)
936         self.assertTrue("Expected tree:" in out_str)
937         self.assertTrue("Actual tree:" in out_str)
938         self.assertEqual("".join(err_lines), "")
939
940     def test_cache_broken_file(self) -> None:
941         mode = black.FileMode()
942         with cache_dir() as workspace:
943             cache_file = black.get_cache_file(mode)
944             with cache_file.open("w") as fobj:
945                 fobj.write("this is not a pickle")
946             self.assertEqual(black.read_cache(mode), {})
947             src = (workspace / "test.py").resolve()
948             with src.open("w") as fobj:
949                 fobj.write("print('hello')")
950             self.invokeBlack([str(src)])
951             cache = black.read_cache(mode)
952             self.assertIn(src, cache)
953
954     def test_cache_single_file_already_cached(self) -> None:
955         mode = black.FileMode()
956         with cache_dir() as workspace:
957             src = (workspace / "test.py").resolve()
958             with src.open("w") as fobj:
959                 fobj.write("print('hello')")
960             black.write_cache({}, [src], mode)
961             self.invokeBlack([str(src)])
962             with src.open("r") as fobj:
963                 self.assertEqual(fobj.read(), "print('hello')")
964
965     @event_loop(close=False)
966     def test_cache_multiple_files(self) -> None:
967         mode = black.FileMode()
968         with cache_dir() as workspace, patch(
969             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
970         ):
971             one = (workspace / "one.py").resolve()
972             with one.open("w") as fobj:
973                 fobj.write("print('hello')")
974             two = (workspace / "two.py").resolve()
975             with two.open("w") as fobj:
976                 fobj.write("print('hello')")
977             black.write_cache({}, [one], mode)
978             self.invokeBlack([str(workspace)])
979             with one.open("r") as fobj:
980                 self.assertEqual(fobj.read(), "print('hello')")
981             with two.open("r") as fobj:
982                 self.assertEqual(fobj.read(), 'print("hello")\n')
983             cache = black.read_cache(mode)
984             self.assertIn(one, cache)
985             self.assertIn(two, cache)
986
987     def test_no_cache_when_writeback_diff(self) -> None:
988         mode = black.FileMode()
989         with cache_dir() as workspace:
990             src = (workspace / "test.py").resolve()
991             with src.open("w") as fobj:
992                 fobj.write("print('hello')")
993             self.invokeBlack([str(src), "--diff"])
994             cache_file = black.get_cache_file(mode)
995             self.assertFalse(cache_file.exists())
996
997     def test_no_cache_when_stdin(self) -> None:
998         mode = black.FileMode()
999         with cache_dir():
1000             result = CliRunner().invoke(
1001                 black.main, ["-"], input=BytesIO(b"print('hello')")
1002             )
1003             self.assertEqual(result.exit_code, 0)
1004             cache_file = black.get_cache_file(mode)
1005             self.assertFalse(cache_file.exists())
1006
1007     def test_read_cache_no_cachefile(self) -> None:
1008         mode = black.FileMode()
1009         with cache_dir():
1010             self.assertEqual(black.read_cache(mode), {})
1011
1012     def test_write_cache_read_cache(self) -> None:
1013         mode = black.FileMode()
1014         with cache_dir() as workspace:
1015             src = (workspace / "test.py").resolve()
1016             src.touch()
1017             black.write_cache({}, [src], mode)
1018             cache = black.read_cache(mode)
1019             self.assertIn(src, cache)
1020             self.assertEqual(cache[src], black.get_cache_info(src))
1021
1022     def test_filter_cached(self) -> None:
1023         with TemporaryDirectory() as workspace:
1024             path = Path(workspace)
1025             uncached = (path / "uncached").resolve()
1026             cached = (path / "cached").resolve()
1027             cached_but_changed = (path / "changed").resolve()
1028             uncached.touch()
1029             cached.touch()
1030             cached_but_changed.touch()
1031             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1032             todo, done = black.filter_cached(
1033                 cache, {uncached, cached, cached_but_changed}
1034             )
1035             self.assertEqual(todo, {uncached, cached_but_changed})
1036             self.assertEqual(done, {cached})
1037
1038     def test_write_cache_creates_directory_if_needed(self) -> None:
1039         mode = black.FileMode()
1040         with cache_dir(exists=False) as workspace:
1041             self.assertFalse(workspace.exists())
1042             black.write_cache({}, [], mode)
1043             self.assertTrue(workspace.exists())
1044
1045     @event_loop(close=False)
1046     def test_failed_formatting_does_not_get_cached(self) -> None:
1047         mode = black.FileMode()
1048         with cache_dir() as workspace, patch(
1049             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1050         ):
1051             failing = (workspace / "failing.py").resolve()
1052             with failing.open("w") as fobj:
1053                 fobj.write("not actually python")
1054             clean = (workspace / "clean.py").resolve()
1055             with clean.open("w") as fobj:
1056                 fobj.write('print("hello")\n')
1057             self.invokeBlack([str(workspace)], exit_code=123)
1058             cache = black.read_cache(mode)
1059             self.assertNotIn(failing, cache)
1060             self.assertIn(clean, cache)
1061
1062     def test_write_cache_write_fail(self) -> None:
1063         mode = black.FileMode()
1064         with cache_dir(), patch.object(Path, "open") as mock:
1065             mock.side_effect = OSError
1066             black.write_cache({}, [], mode)
1067
1068     @event_loop(close=False)
1069     def test_check_diff_use_together(self) -> None:
1070         with cache_dir():
1071             # Files which will be reformatted.
1072             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1073             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1074             # Files which will not be reformatted.
1075             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1076             self.invokeBlack([str(src2), "--diff", "--check"])
1077             # Multi file command.
1078             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1079
1080     def test_no_files(self) -> None:
1081         with cache_dir():
1082             # Without an argument, black exits with error code 0.
1083             self.invokeBlack([])
1084
1085     def test_broken_symlink(self) -> None:
1086         with cache_dir() as workspace:
1087             symlink = workspace / "broken_link.py"
1088             try:
1089                 symlink.symlink_to("nonexistent.py")
1090             except OSError as e:
1091                 self.skipTest(f"Can't create symlinks: {e}")
1092             self.invokeBlack([str(workspace.resolve())])
1093
1094     def test_read_cache_line_lengths(self) -> None:
1095         mode = black.FileMode()
1096         short_mode = black.FileMode(line_length=1)
1097         with cache_dir() as workspace:
1098             path = (workspace / "file.py").resolve()
1099             path.touch()
1100             black.write_cache({}, [path], mode)
1101             one = black.read_cache(mode)
1102             self.assertIn(path, one)
1103             two = black.read_cache(short_mode)
1104             self.assertNotIn(path, two)
1105
1106     def test_single_file_force_pyi(self) -> None:
1107         reg_mode = black.FileMode()
1108         pyi_mode = black.FileMode(is_pyi=True)
1109         contents, expected = read_data("force_pyi")
1110         with cache_dir() as workspace:
1111             path = (workspace / "file.py").resolve()
1112             with open(path, "w") as fh:
1113                 fh.write(contents)
1114             self.invokeBlack([str(path), "--pyi"])
1115             with open(path, "r") as fh:
1116                 actual = fh.read()
1117             # verify cache with --pyi is separate
1118             pyi_cache = black.read_cache(pyi_mode)
1119             self.assertIn(path, pyi_cache)
1120             normal_cache = black.read_cache(reg_mode)
1121             self.assertNotIn(path, normal_cache)
1122         self.assertEqual(actual, expected)
1123
1124     @event_loop(close=False)
1125     def test_multi_file_force_pyi(self) -> None:
1126         reg_mode = black.FileMode()
1127         pyi_mode = black.FileMode(is_pyi=True)
1128         contents, expected = read_data("force_pyi")
1129         with cache_dir() as workspace:
1130             paths = [
1131                 (workspace / "file1.py").resolve(),
1132                 (workspace / "file2.py").resolve(),
1133             ]
1134             for path in paths:
1135                 with open(path, "w") as fh:
1136                     fh.write(contents)
1137             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1138             for path in paths:
1139                 with open(path, "r") as fh:
1140                     actual = fh.read()
1141                 self.assertEqual(actual, expected)
1142             # verify cache with --pyi is separate
1143             pyi_cache = black.read_cache(pyi_mode)
1144             normal_cache = black.read_cache(reg_mode)
1145             for path in paths:
1146                 self.assertIn(path, pyi_cache)
1147                 self.assertNotIn(path, normal_cache)
1148
1149     def test_pipe_force_pyi(self) -> None:
1150         source, expected = read_data("force_pyi")
1151         result = CliRunner().invoke(
1152             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1153         )
1154         self.assertEqual(result.exit_code, 0)
1155         actual = result.output
1156         self.assertFormatEqual(actual, expected)
1157
1158     def test_single_file_force_py36(self) -> None:
1159         reg_mode = black.FileMode()
1160         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1161         source, expected = read_data("force_py36")
1162         with cache_dir() as workspace:
1163             path = (workspace / "file.py").resolve()
1164             with open(path, "w") as fh:
1165                 fh.write(source)
1166             self.invokeBlack([str(path), "--py36"])
1167             with open(path, "r") as fh:
1168                 actual = fh.read()
1169             # verify cache with --py36 is separate
1170             py36_cache = black.read_cache(py36_mode)
1171             self.assertIn(path, py36_cache)
1172             normal_cache = black.read_cache(reg_mode)
1173             self.assertNotIn(path, normal_cache)
1174         self.assertEqual(actual, expected)
1175
1176     @event_loop(close=False)
1177     def test_multi_file_force_py36(self) -> None:
1178         reg_mode = black.FileMode()
1179         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1180         source, expected = read_data("force_py36")
1181         with cache_dir() as workspace:
1182             paths = [
1183                 (workspace / "file1.py").resolve(),
1184                 (workspace / "file2.py").resolve(),
1185             ]
1186             for path in paths:
1187                 with open(path, "w") as fh:
1188                     fh.write(source)
1189             self.invokeBlack([str(p) for p in paths] + ["--py36"])
1190             for path in paths:
1191                 with open(path, "r") as fh:
1192                     actual = fh.read()
1193                 self.assertEqual(actual, expected)
1194             # verify cache with --py36 is separate
1195             pyi_cache = black.read_cache(py36_mode)
1196             normal_cache = black.read_cache(reg_mode)
1197             for path in paths:
1198                 self.assertIn(path, pyi_cache)
1199                 self.assertNotIn(path, normal_cache)
1200
1201     def test_pipe_force_py36(self) -> None:
1202         source, expected = read_data("force_py36")
1203         result = CliRunner().invoke(
1204             black.main, ["-", "-q", "--py36"], input=BytesIO(source.encode("utf8"))
1205         )
1206         self.assertEqual(result.exit_code, 0)
1207         actual = result.output
1208         self.assertFormatEqual(actual, expected)
1209
1210     def test_include_exclude(self) -> None:
1211         path = THIS_DIR / "data" / "include_exclude_tests"
1212         include = re.compile(r"\.pyi?$")
1213         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1214         report = black.Report()
1215         sources: List[Path] = []
1216         expected = [
1217             Path(path / "b/dont_exclude/a.py"),
1218             Path(path / "b/dont_exclude/a.pyi"),
1219         ]
1220         this_abs = THIS_DIR.resolve()
1221         sources.extend(
1222             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1223         )
1224         self.assertEqual(sorted(expected), sorted(sources))
1225
1226     def test_empty_include(self) -> None:
1227         path = THIS_DIR / "data" / "include_exclude_tests"
1228         report = black.Report()
1229         empty = re.compile(r"")
1230         sources: List[Path] = []
1231         expected = [
1232             Path(path / "b/exclude/a.pie"),
1233             Path(path / "b/exclude/a.py"),
1234             Path(path / "b/exclude/a.pyi"),
1235             Path(path / "b/dont_exclude/a.pie"),
1236             Path(path / "b/dont_exclude/a.py"),
1237             Path(path / "b/dont_exclude/a.pyi"),
1238             Path(path / "b/.definitely_exclude/a.pie"),
1239             Path(path / "b/.definitely_exclude/a.py"),
1240             Path(path / "b/.definitely_exclude/a.pyi"),
1241         ]
1242         this_abs = THIS_DIR.resolve()
1243         sources.extend(
1244             black.gen_python_files_in_dir(
1245                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1246             )
1247         )
1248         self.assertEqual(sorted(expected), sorted(sources))
1249
1250     def test_empty_exclude(self) -> None:
1251         path = THIS_DIR / "data" / "include_exclude_tests"
1252         report = black.Report()
1253         empty = re.compile(r"")
1254         sources: List[Path] = []
1255         expected = [
1256             Path(path / "b/dont_exclude/a.py"),
1257             Path(path / "b/dont_exclude/a.pyi"),
1258             Path(path / "b/exclude/a.py"),
1259             Path(path / "b/exclude/a.pyi"),
1260             Path(path / "b/.definitely_exclude/a.py"),
1261             Path(path / "b/.definitely_exclude/a.pyi"),
1262         ]
1263         this_abs = THIS_DIR.resolve()
1264         sources.extend(
1265             black.gen_python_files_in_dir(
1266                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1267             )
1268         )
1269         self.assertEqual(sorted(expected), sorted(sources))
1270
1271     def test_invalid_include_exclude(self) -> None:
1272         for option in ["--include", "--exclude"]:
1273             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1274
1275     def test_preserves_line_endings(self) -> None:
1276         with TemporaryDirectory() as workspace:
1277             test_file = Path(workspace) / "test.py"
1278             for nl in ["\n", "\r\n"]:
1279                 contents = nl.join(["def f(  ):", "    pass"])
1280                 test_file.write_bytes(contents.encode())
1281                 ff(test_file, write_back=black.WriteBack.YES)
1282                 updated_contents: bytes = test_file.read_bytes()
1283                 self.assertIn(nl.encode(), updated_contents)
1284                 if nl == "\n":
1285                     self.assertNotIn(b"\r\n", updated_contents)
1286
1287     def test_preserves_line_endings_via_stdin(self) -> None:
1288         for nl in ["\n", "\r\n"]:
1289             contents = nl.join(["def f(  ):", "    pass"])
1290             runner = BlackRunner()
1291             result = runner.invoke(
1292                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1293             )
1294             self.assertEqual(result.exit_code, 0)
1295             output = runner.stdout_bytes
1296             self.assertIn(nl.encode("utf8"), output)
1297             if nl == "\n":
1298                 self.assertNotIn(b"\r\n", output)
1299
1300     def test_assert_equivalent_different_asts(self) -> None:
1301         with self.assertRaises(AssertionError):
1302             black.assert_equivalent("{}", "None")
1303
1304     def test_symlink_out_of_root_directory(self) -> None:
1305         path = MagicMock()
1306         root = THIS_DIR
1307         child = MagicMock()
1308         include = re.compile(black.DEFAULT_INCLUDES)
1309         exclude = re.compile(black.DEFAULT_EXCLUDES)
1310         report = black.Report()
1311         # `child` should behave like a symlink which resolved path is clearly
1312         # outside of the `root` directory.
1313         path.iterdir.return_value = [child]
1314         child.resolve.return_value = Path("/a/b/c")
1315         child.is_symlink.return_value = True
1316         try:
1317             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1318         except ValueError as ve:
1319             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1320         path.iterdir.assert_called_once()
1321         child.resolve.assert_called_once()
1322         child.is_symlink.assert_called_once()
1323         # `child` should behave like a strange file which resolved path is clearly
1324         # outside of the `root` directory.
1325         child.is_symlink.return_value = False
1326         with self.assertRaises(ValueError):
1327             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1328         path.iterdir.assert_called()
1329         self.assertEqual(path.iterdir.call_count, 2)
1330         child.resolve.assert_called()
1331         self.assertEqual(child.resolve.call_count, 2)
1332         child.is_symlink.assert_called()
1333         self.assertEqual(child.is_symlink.call_count, 2)
1334
1335     def test_shhh_click(self) -> None:
1336         try:
1337             from click import _unicodefun  # type: ignore
1338         except ModuleNotFoundError:
1339             self.skipTest("Incompatible Click version")
1340         if not hasattr(_unicodefun, "_verify_python3_env"):
1341             self.skipTest("Incompatible Click version")
1342         # First, let's see if Click is crashing with a preferred ASCII charset.
1343         with patch("locale.getpreferredencoding") as gpe:
1344             gpe.return_value = "ASCII"
1345             with self.assertRaises(RuntimeError):
1346                 _unicodefun._verify_python3_env()
1347         # Now, let's silence Click...
1348         black.patch_click()
1349         # ...and confirm it's silent.
1350         with patch("locale.getpreferredencoding") as gpe:
1351             gpe.return_value = "ASCII"
1352             try:
1353                 _unicodefun._verify_python3_env()
1354             except RuntimeError as re:
1355                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1356
1357     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1358     @async_test
1359     async def test_blackd_request_needs_formatting(self) -> None:
1360         app = blackd.make_app()
1361         async with TestClient(TestServer(app)) as client:
1362             response = await client.post("/", data=b"print('hello world')")
1363             self.assertEqual(response.status, 200)
1364             self.assertEqual(response.charset, "utf8")
1365             self.assertEqual(await response.read(), b'print("hello world")\n')
1366
1367     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1368     @async_test
1369     async def test_blackd_request_no_change(self) -> None:
1370         app = blackd.make_app()
1371         async with TestClient(TestServer(app)) as client:
1372             response = await client.post("/", data=b'print("hello world")\n')
1373             self.assertEqual(response.status, 204)
1374             self.assertEqual(await response.read(), b"")
1375
1376     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1377     @async_test
1378     async def test_blackd_request_syntax_error(self) -> None:
1379         app = blackd.make_app()
1380         async with TestClient(TestServer(app)) as client:
1381             response = await client.post("/", data=b"what even ( is")
1382             self.assertEqual(response.status, 400)
1383             content = await response.text()
1384             self.assertTrue(
1385                 content.startswith("Cannot parse"),
1386                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1387             )
1388
1389     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1390     @async_test
1391     async def test_blackd_unsupported_version(self) -> None:
1392         app = blackd.make_app()
1393         async with TestClient(TestServer(app)) as client:
1394             response = await client.post(
1395                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1396             )
1397             self.assertEqual(response.status, 501)
1398
1399     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1400     @async_test
1401     async def test_blackd_supported_version(self) -> None:
1402         app = blackd.make_app()
1403         async with TestClient(TestServer(app)) as client:
1404             response = await client.post(
1405                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1406             )
1407             self.assertEqual(response.status, 200)
1408
1409     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1410     @async_test
1411     async def test_blackd_invalid_python_variant(self) -> None:
1412         app = blackd.make_app()
1413         async with TestClient(TestServer(app)) as client:
1414
1415             async def check(header_value: str, expected_status: int = 400) -> None:
1416                 response = await client.post(
1417                     "/",
1418                     data=b"what",
1419                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1420                 )
1421                 self.assertEqual(response.status, expected_status)
1422
1423             await check("lol")
1424             await check("ruby3.5")
1425             await check("pyi3.6")
1426             await check("cpy1.5")
1427             await check("2.8")
1428             await check("cpy2.8")
1429             await check("3.0")
1430             await check("pypy3.0")
1431             await check("jython3.4")
1432
1433     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1434     @async_test
1435     async def test_blackd_pyi(self) -> None:
1436         app = blackd.make_app()
1437         async with TestClient(TestServer(app)) as client:
1438             source, expected = read_data("stub.pyi")
1439             response = await client.post(
1440                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1441             )
1442             self.assertEqual(response.status, 200)
1443             self.assertEqual(await response.text(), expected)
1444
1445     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1446     @async_test
1447     async def test_blackd_python_variant(self) -> None:
1448         app = blackd.make_app()
1449         code = (
1450             "def f(\n"
1451             "    and_has_a_bunch_of,\n"
1452             "    very_long_arguments_too,\n"
1453             "    and_lots_of_them_as_well_lol,\n"
1454             "    **and_very_long_keyword_arguments\n"
1455             "):\n"
1456             "    pass\n"
1457         )
1458         async with TestClient(TestServer(app)) as client:
1459
1460             async def check(header_value: str, expected_status: int) -> None:
1461                 response = await client.post(
1462                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1463                 )
1464                 self.assertEqual(response.status, expected_status)
1465
1466             await check("3.6", 200)
1467             await check("cpy3.6", 200)
1468             await check("3.5,3.7", 200)
1469             await check("3.5,cpy3.7", 200)
1470
1471             await check("2", 204)
1472             await check("2.7", 204)
1473             await check("cpy2.7", 204)
1474             await check("pypy2.7", 204)
1475             await check("3.4", 204)
1476             await check("cpy3.4", 204)
1477             await check("pypy3.4", 204)
1478
1479     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1480     @async_test
1481     async def test_blackd_fast(self) -> None:
1482         with open(os.devnull, "w") as dn, redirect_stderr(dn):
1483             app = blackd.make_app()
1484             async with TestClient(TestServer(app)) as client:
1485                 response = await client.post("/", data=b"ur'hello'")
1486                 self.assertEqual(response.status, 500)
1487                 self.assertIn("failed to parse source file", await response.text())
1488                 response = await client.post(
1489                     "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1490                 )
1491                 self.assertEqual(response.status, 200)
1492
1493     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1494     @async_test
1495     async def test_blackd_line_length(self) -> None:
1496         app = blackd.make_app()
1497         async with TestClient(TestServer(app)) as client:
1498             response = await client.post(
1499                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1500             )
1501             self.assertEqual(response.status, 200)
1502
1503     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1504     @async_test
1505     async def test_blackd_invalid_line_length(self) -> None:
1506         app = blackd.make_app()
1507         async with TestClient(TestServer(app)) as client:
1508             response = await client.post(
1509                 "/",
1510                 data=b'print("hello")\n',
1511                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1512             )
1513             self.assertEqual(response.status, 400)
1514
1515     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1516     def test_blackd_main(self) -> None:
1517         with patch("blackd.web.run_app"):
1518             result = CliRunner().invoke(blackd.main, [])
1519             if result.exception is not None:
1520                 raise result.exception
1521             self.assertEqual(result.exit_code, 0)
1522
1523
if __name__ == "__main__":
    # Run the test suite when this file is executed directly. Passing
    # module="test_black" makes unittest load tests from the module imported
    # under that name rather than from `__main__`.
    # NOTE(review): this presumably relies on this file's directory being on
    # sys.path (as it is when run as a script) so `test_black` is importable.
    unittest.main(module="test_black")