git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

3404e058e1b069abc46830e41953bb99feeee169
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager, redirect_stderr
5 from functools import partial, wraps
6 from io import BytesIO, TextIOWrapper
7 import os
8 from pathlib import Path
9 import re
10 import sys
11 from tempfile import TemporaryDirectory
12 from typing import (
13     Any,
14     BinaryIO,
15     Callable,
16     Coroutine,
17     Generator,
18     List,
19     Tuple,
20     Iterator,
21     TypeVar,
22 )
23 import unittest
24 from unittest.mock import patch, MagicMock
25
26 from click import unstyle
27 from click.testing import CliRunner
28
29 import black
30 from black import Feature
31
32 try:
33     import blackd
34     from aiohttp.test_utils import TestClient, TestServer
35 except ImportError:
36     has_blackd_deps = False
37 else:
38     has_blackd_deps = True
39
40
# Shorthands for formatting with the default FileMode: ``fs`` formats a string,
# ``ff`` formats a file in place (with ``fast=True``).
fs = partial(black.format_str, mode=black.FileMode())
ff = partial(black.format_file_in_place, fast=True, mode=black.FileMode())
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# read_data() deletes this marker from every line it reads.  It is spelled as
# two literals so that this file, which test_self loads via read_data(), does
# not itself contain the marker.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# Generic type variables for annotations.
T = TypeVar("T")
R = TypeVar("R")
48
49
def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines and pad the result with a newline on each side.

    Patched in for ``black.dump_to_file`` by the tests, presumably so that
    failure messages embed the content itself instead of a temp-file path.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
52
53
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Lines before the first ``# output`` marker line form the input; lines after
    it form the expected output (any further marker lines are dropped).  With no
    marker at all, the input doubles as the expected output.

    *name* gets a ``.py`` suffix unless it already ends in .py/.pyi/.out/.diff.
    It is looked up under ``THIS_DIR / "data"`` when *data* is true, otherwise
    directly under ``THIS_DIR``.  The EMPTY_LINE marker is removed from every
    line, and each part is stripped and given a single trailing newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    sections: List[List[str]] = [[], []]
    current = 0
    for raw in lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            current = 1
        else:
            sections[current].append(line)
    _input, _output = sections
    if _input and not _output:
        # No output marker: the whole file is expected to be pre-formatted.
        _output = list(_input)
    return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
75
76
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory for the block.

    With *exists* false, the patched path is a not-yet-created ``new``
    subdirectory of the temporary directory.  Yields the patched path.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
85
86
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a fresh asyncio event loop as the current loop for the block.

    The previously current loop is restored on exit.  If there was no current
    loop (``get_event_loop()`` raised ``RuntimeError``), ``None`` is restored.

    :param close: close the temporary loop on exit; otherwise it is left open.
    """
    policy = asyncio.get_event_loop_policy()
    old_loop = None
    try:
        old_loop = policy.get_event_loop()
    except RuntimeError:
        # No loop is current (e.g. outside the main thread, or after
        # set_event_loop(None)); None is what gets restored on exit.
        pass
    loop = policy.new_event_loop()
    policy.set_event_loop(loop)
    try:
        yield
    finally:
        policy.set_event_loop(old_loop)
        if close:
            loop.close()
100
101
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn the coroutine function *f* into a plain synchronous callable.

    Each call runs ``f(*args, **kwargs)`` to completion on a fresh event loop,
    installed via ``event_loop`` and closed afterwards.  The result is discarded.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        with event_loop(close=True):
            loop = asyncio.get_event_loop()
            loop.run_until_complete(f(*args, **kwargs))

    return wrapper
109
110
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run Click's isolation with ``sys.stderr`` redirected to ``stderrbuf``.

        On exit the captured bytes are stored in ``stdout_bytes`` and
        ``stderr_bytes``, and the original ``sys.stderr`` is put back.
        """
        with super().isolation(*args, **kwargs) as output:
            # Fresh buffer per invocation, so that output from an earlier run
            # does not leak into this one.
            self.stderrbuf = BytesIO()
            hold_stderr = sys.stderr
            stderr_wrapper = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = stderr_wrapper
            try:
                yield output
            finally:
                # TextIOWrapper buffers text internally; flush it so nothing
                # written since the last flush is missing from stderr_bytes.
                stderr_wrapper.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
                sys.stderr = hold_stderr
                # Detach so that collecting the wrapper doesn't close stderrbuf.
                stderr_wrapper.detach()
134
135
136 class BlackTestCase(unittest.TestCase):
137     maxDiff = None
138
139     def assertFormatEqual(self, expected: str, actual: str) -> None:
140         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
141             bdv: black.DebugVisitor[Any]
142             black.out("Expected tree:", fg="green")
143             try:
144                 exp_node = black.lib2to3_parse(expected)
145                 bdv = black.DebugVisitor()
146                 list(bdv.visit(exp_node))
147             except Exception as ve:
148                 black.err(str(ve))
149             black.out("Actual tree:", fg="red")
150             try:
151                 exp_node = black.lib2to3_parse(actual)
152                 bdv = black.DebugVisitor()
153                 list(bdv.visit(exp_node))
154             except Exception as ve:
155                 black.err(str(ve))
156         self.assertEqual(expected, actual)
157
158     def invokeBlack(
159         self, args: List[str], exit_code: int = 0, ignore_config: bool = True
160     ) -> None:
161         runner = BlackRunner()
162         if ignore_config:
163             args = ["--config", str(THIS_DIR / "empty.toml"), *args]
164         result = runner.invoke(black.main, args)
165         self.assertEqual(result.exit_code, exit_code, msg=runner.stderr_bytes.decode())
166
167     @patch("black.dump_to_file", dump_to_stderr)
168     def test_empty(self) -> None:
169         source = expected = ""
170         actual = fs(source)
171         self.assertFormatEqual(expected, actual)
172         black.assert_equivalent(source, actual)
173         black.assert_stable(source, actual, black.FileMode())
174
175     def test_empty_ff(self) -> None:
176         expected = ""
177         tmp_file = Path(black.dump_to_file())
178         try:
179             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
180             with open(tmp_file, encoding="utf8") as f:
181                 actual = f.read()
182         finally:
183             os.unlink(tmp_file)
184         self.assertFormatEqual(expected, actual)
185
186     @patch("black.dump_to_file", dump_to_stderr)
187     def test_self(self) -> None:
188         source, expected = read_data("test_black", data=False)
189         actual = fs(source)
190         self.assertFormatEqual(expected, actual)
191         black.assert_equivalent(source, actual)
192         black.assert_stable(source, actual, black.FileMode())
193         self.assertFalse(ff(THIS_FILE))
194
195     @patch("black.dump_to_file", dump_to_stderr)
196     def test_black(self) -> None:
197         source, expected = read_data("../black", data=False)
198         actual = fs(source)
199         self.assertFormatEqual(expected, actual)
200         black.assert_equivalent(source, actual)
201         black.assert_stable(source, actual, black.FileMode())
202         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
203
204     def test_piping(self) -> None:
205         source, expected = read_data("../black", data=False)
206         result = BlackRunner().invoke(
207             black.main,
208             ["-", "--fast", f"--line-length={black.DEFAULT_LINE_LENGTH}"],
209             input=BytesIO(source.encode("utf8")),
210         )
211         self.assertEqual(result.exit_code, 0)
212         self.assertFormatEqual(expected, result.output)
213         black.assert_equivalent(source, result.output)
214         black.assert_stable(source, result.output, black.FileMode())
215
216     def test_piping_diff(self) -> None:
217         diff_header = re.compile(
218             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
219             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
220         )
221         source, _ = read_data("expression.py")
222         expected, _ = read_data("expression.diff")
223         config = THIS_DIR / "data" / "empty_pyproject.toml"
224         args = [
225             "-",
226             "--fast",
227             f"--line-length={black.DEFAULT_LINE_LENGTH}",
228             "--diff",
229             f"--config={config}",
230         ]
231         result = BlackRunner().invoke(
232             black.main, args, input=BytesIO(source.encode("utf8"))
233         )
234         self.assertEqual(result.exit_code, 0)
235         actual = diff_header.sub("[Deterministic header]", result.output)
236         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
237         self.assertEqual(expected, actual)
238
239     @patch("black.dump_to_file", dump_to_stderr)
240     def test_setup(self) -> None:
241         source, expected = read_data("../setup", data=False)
242         actual = fs(source)
243         self.assertFormatEqual(expected, actual)
244         black.assert_equivalent(source, actual)
245         black.assert_stable(source, actual, black.FileMode())
246         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
247
248     @patch("black.dump_to_file", dump_to_stderr)
249     def test_function(self) -> None:
250         source, expected = read_data("function")
251         actual = fs(source)
252         self.assertFormatEqual(expected, actual)
253         black.assert_equivalent(source, actual)
254         black.assert_stable(source, actual, black.FileMode())
255
256     @patch("black.dump_to_file", dump_to_stderr)
257     def test_function2(self) -> None:
258         source, expected = read_data("function2")
259         actual = fs(source)
260         self.assertFormatEqual(expected, actual)
261         black.assert_equivalent(source, actual)
262         black.assert_stable(source, actual, black.FileMode())
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_expression(self) -> None:
266         source, expected = read_data("expression")
267         actual = fs(source)
268         self.assertFormatEqual(expected, actual)
269         black.assert_equivalent(source, actual)
270         black.assert_stable(source, actual, black.FileMode())
271
272     def test_expression_ff(self) -> None:
273         source, expected = read_data("expression")
274         tmp_file = Path(black.dump_to_file(source))
275         try:
276             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
277             with open(tmp_file, encoding="utf8") as f:
278                 actual = f.read()
279         finally:
280             os.unlink(tmp_file)
281         self.assertFormatEqual(expected, actual)
282         with patch("black.dump_to_file", dump_to_stderr):
283             black.assert_equivalent(source, actual)
284             black.assert_stable(source, actual, black.FileMode())
285
286     def test_expression_diff(self) -> None:
287         source, _ = read_data("expression.py")
288         expected, _ = read_data("expression.diff")
289         tmp_file = Path(black.dump_to_file(source))
290         diff_header = re.compile(
291             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
292             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
293         )
294         try:
295             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
296             self.assertEqual(result.exit_code, 0)
297         finally:
298             os.unlink(tmp_file)
299         actual = result.output
300         actual = diff_header.sub("[Deterministic header]", actual)
301         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
302         if expected != actual:
303             dump = black.dump_to_file(actual)
304             msg = (
305                 f"Expected diff isn't equal to the actual. If you made changes "
306                 f"to expression.py and this is an anticipated difference, "
307                 f"overwrite tests/data/expression.diff with {dump}"
308             )
309             self.assertEqual(expected, actual, msg)
310
311     @patch("black.dump_to_file", dump_to_stderr)
312     def test_fstring(self) -> None:
313         source, expected = read_data("fstring")
314         actual = fs(source)
315         self.assertFormatEqual(expected, actual)
316         black.assert_equivalent(source, actual)
317         black.assert_stable(source, actual, black.FileMode())
318
319     @patch("black.dump_to_file", dump_to_stderr)
320     def test_string_quotes(self) -> None:
321         source, expected = read_data("string_quotes")
322         actual = fs(source)
323         self.assertFormatEqual(expected, actual)
324         black.assert_equivalent(source, actual)
325         black.assert_stable(source, actual, black.FileMode())
326         mode = black.FileMode(string_normalization=False)
327         not_normalized = fs(source, mode=mode)
328         self.assertFormatEqual(source, not_normalized)
329         black.assert_equivalent(source, not_normalized)
330         black.assert_stable(source, not_normalized, mode=mode)
331
332     @patch("black.dump_to_file", dump_to_stderr)
333     def test_slices(self) -> None:
334         source, expected = read_data("slices")
335         actual = fs(source)
336         self.assertFormatEqual(expected, actual)
337         black.assert_equivalent(source, actual)
338         black.assert_stable(source, actual, black.FileMode())
339
340     @patch("black.dump_to_file", dump_to_stderr)
341     def test_comments(self) -> None:
342         source, expected = read_data("comments")
343         actual = fs(source)
344         self.assertFormatEqual(expected, actual)
345         black.assert_equivalent(source, actual)
346         black.assert_stable(source, actual, black.FileMode())
347
348     @patch("black.dump_to_file", dump_to_stderr)
349     def test_comments2(self) -> None:
350         source, expected = read_data("comments2")
351         actual = fs(source)
352         self.assertFormatEqual(expected, actual)
353         black.assert_equivalent(source, actual)
354         black.assert_stable(source, actual, black.FileMode())
355
356     @patch("black.dump_to_file", dump_to_stderr)
357     def test_comments3(self) -> None:
358         source, expected = read_data("comments3")
359         actual = fs(source)
360         self.assertFormatEqual(expected, actual)
361         black.assert_equivalent(source, actual)
362         black.assert_stable(source, actual, black.FileMode())
363
364     @patch("black.dump_to_file", dump_to_stderr)
365     def test_comments4(self) -> None:
366         source, expected = read_data("comments4")
367         actual = fs(source)
368         self.assertFormatEqual(expected, actual)
369         black.assert_equivalent(source, actual)
370         black.assert_stable(source, actual, black.FileMode())
371
372     @patch("black.dump_to_file", dump_to_stderr)
373     def test_comments5(self) -> None:
374         source, expected = read_data("comments5")
375         actual = fs(source)
376         self.assertFormatEqual(expected, actual)
377         black.assert_equivalent(source, actual)
378         black.assert_stable(source, actual, black.FileMode())
379
380     @patch("black.dump_to_file", dump_to_stderr)
381     def test_comments6(self) -> None:
382         source, expected = read_data("comments6")
383         actual = fs(source)
384         self.assertFormatEqual(expected, actual)
385         black.assert_equivalent(source, actual)
386         black.assert_stable(source, actual, black.FileMode())
387
388     @patch("black.dump_to_file", dump_to_stderr)
389     def test_cantfit(self) -> None:
390         source, expected = read_data("cantfit")
391         actual = fs(source)
392         self.assertFormatEqual(expected, actual)
393         black.assert_equivalent(source, actual)
394         black.assert_stable(source, actual, black.FileMode())
395
396     @patch("black.dump_to_file", dump_to_stderr)
397     def test_import_spacing(self) -> None:
398         source, expected = read_data("import_spacing")
399         actual = fs(source)
400         self.assertFormatEqual(expected, actual)
401         black.assert_equivalent(source, actual)
402         black.assert_stable(source, actual, black.FileMode())
403
404     @patch("black.dump_to_file", dump_to_stderr)
405     def test_composition(self) -> None:
406         source, expected = read_data("composition")
407         actual = fs(source)
408         self.assertFormatEqual(expected, actual)
409         black.assert_equivalent(source, actual)
410         black.assert_stable(source, actual, black.FileMode())
411
412     @patch("black.dump_to_file", dump_to_stderr)
413     def test_empty_lines(self) -> None:
414         source, expected = read_data("empty_lines")
415         actual = fs(source)
416         self.assertFormatEqual(expected, actual)
417         black.assert_equivalent(source, actual)
418         black.assert_stable(source, actual, black.FileMode())
419
420     @patch("black.dump_to_file", dump_to_stderr)
421     def test_string_prefixes(self) -> None:
422         source, expected = read_data("string_prefixes")
423         actual = fs(source)
424         self.assertFormatEqual(expected, actual)
425         black.assert_equivalent(source, actual)
426         black.assert_stable(source, actual, black.FileMode())
427
428     @patch("black.dump_to_file", dump_to_stderr)
429     def test_numeric_literals(self) -> None:
430         source, expected = read_data("numeric_literals")
431         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
432         actual = fs(source, mode=mode)
433         self.assertFormatEqual(expected, actual)
434         black.assert_equivalent(source, actual)
435         black.assert_stable(source, actual, mode)
436
437     @patch("black.dump_to_file", dump_to_stderr)
438     def test_numeric_literals_ignoring_underscores(self) -> None:
439         source, expected = read_data("numeric_literals_skip_underscores")
440         mode = black.FileMode(target_versions=black.PY36_VERSIONS)
441         actual = fs(source, mode=mode)
442         self.assertFormatEqual(expected, actual)
443         black.assert_equivalent(source, actual)
444         black.assert_stable(source, actual, mode)
445
446     @patch("black.dump_to_file", dump_to_stderr)
447     def test_numeric_literals_py2(self) -> None:
448         source, expected = read_data("numeric_literals_py2")
449         actual = fs(source)
450         self.assertFormatEqual(expected, actual)
451         black.assert_stable(source, actual, black.FileMode())
452
453     @patch("black.dump_to_file", dump_to_stderr)
454     def test_python2(self) -> None:
455         source, expected = read_data("python2")
456         actual = fs(source)
457         self.assertFormatEqual(expected, actual)
458         # black.assert_equivalent(source, actual)
459         black.assert_stable(source, actual, black.FileMode())
460
461     @patch("black.dump_to_file", dump_to_stderr)
462     def test_python2_unicode_literals(self) -> None:
463         source, expected = read_data("python2_unicode_literals")
464         actual = fs(source)
465         self.assertFormatEqual(expected, actual)
466         black.assert_stable(source, actual, black.FileMode())
467
468     @patch("black.dump_to_file", dump_to_stderr)
469     def test_stub(self) -> None:
470         mode = black.FileMode(is_pyi=True)
471         source, expected = read_data("stub.pyi")
472         actual = fs(source, mode=mode)
473         self.assertFormatEqual(expected, actual)
474         black.assert_stable(source, actual, mode)
475
476     @patch("black.dump_to_file", dump_to_stderr)
477     def test_python37(self) -> None:
478         source, expected = read_data("python37")
479         actual = fs(source)
480         self.assertFormatEqual(expected, actual)
481         major, minor = sys.version_info[:2]
482         if major > 3 or (major == 3 and minor >= 7):
483             black.assert_equivalent(source, actual)
484         black.assert_stable(source, actual, black.FileMode())
485
486     @patch("black.dump_to_file", dump_to_stderr)
487     def test_fmtonoff(self) -> None:
488         source, expected = read_data("fmtonoff")
489         actual = fs(source)
490         self.assertFormatEqual(expected, actual)
491         black.assert_equivalent(source, actual)
492         black.assert_stable(source, actual, black.FileMode())
493
494     @patch("black.dump_to_file", dump_to_stderr)
495     def test_fmtonoff2(self) -> None:
496         source, expected = read_data("fmtonoff2")
497         actual = fs(source)
498         self.assertFormatEqual(expected, actual)
499         black.assert_equivalent(source, actual)
500         black.assert_stable(source, actual, black.FileMode())
501
502     @patch("black.dump_to_file", dump_to_stderr)
503     def test_remove_empty_parentheses_after_class(self) -> None:
504         source, expected = read_data("class_blank_parentheses")
505         actual = fs(source)
506         self.assertFormatEqual(expected, actual)
507         black.assert_equivalent(source, actual)
508         black.assert_stable(source, actual, black.FileMode())
509
510     @patch("black.dump_to_file", dump_to_stderr)
511     def test_new_line_between_class_and_code(self) -> None:
512         source, expected = read_data("class_methods_new_line")
513         actual = fs(source)
514         self.assertFormatEqual(expected, actual)
515         black.assert_equivalent(source, actual)
516         black.assert_stable(source, actual, black.FileMode())
517
518     @patch("black.dump_to_file", dump_to_stderr)
519     def test_bracket_match(self) -> None:
520         source, expected = read_data("bracketmatch")
521         actual = fs(source)
522         self.assertFormatEqual(expected, actual)
523         black.assert_equivalent(source, actual)
524         black.assert_stable(source, actual, black.FileMode())
525
526     def test_comment_indentation(self) -> None:
527         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t# comment\n\tpass\n"
528         contents_spc = "if 1:\n    if 2:\n        pass\n    # comment\n    pass\n"
529
530         self.assertFormatEqual(fs(contents_spc), contents_spc)
531         self.assertFormatEqual(fs(contents_tab), contents_spc)
532
533         contents_tab = "if 1:\n\tif 2:\n\t\tpass\n\t\t# comment\n\tpass\n"
534         contents_spc = "if 1:\n    if 2:\n        pass\n        # comment\n    pass\n"
535
536         self.assertFormatEqual(fs(contents_tab), contents_spc)
537         self.assertFormatEqual(fs(contents_spc), contents_spc)
538
539     def test_report_verbose(self) -> None:
540         report = black.Report(verbose=True)
541         out_lines = []
542         err_lines = []
543
544         def out(msg: str, **kwargs: Any) -> None:
545             out_lines.append(msg)
546
547         def err(msg: str, **kwargs: Any) -> None:
548             err_lines.append(msg)
549
550         with patch("black.out", out), patch("black.err", err):
551             report.done(Path("f1"), black.Changed.NO)
552             self.assertEqual(len(out_lines), 1)
553             self.assertEqual(len(err_lines), 0)
554             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
555             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
556             self.assertEqual(report.return_code, 0)
557             report.done(Path("f2"), black.Changed.YES)
558             self.assertEqual(len(out_lines), 2)
559             self.assertEqual(len(err_lines), 0)
560             self.assertEqual(out_lines[-1], "reformatted f2")
561             self.assertEqual(
562                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
563             )
564             report.done(Path("f3"), black.Changed.CACHED)
565             self.assertEqual(len(out_lines), 3)
566             self.assertEqual(len(err_lines), 0)
567             self.assertEqual(
568                 out_lines[-1], "f3 wasn't modified on disk since last run."
569             )
570             self.assertEqual(
571                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
572             )
573             self.assertEqual(report.return_code, 0)
574             report.check = True
575             self.assertEqual(report.return_code, 1)
576             report.check = False
577             report.failed(Path("e1"), "boom")
578             self.assertEqual(len(out_lines), 3)
579             self.assertEqual(len(err_lines), 1)
580             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
581             self.assertEqual(
582                 unstyle(str(report)),
583                 "1 file reformatted, 2 files left unchanged, "
584                 "1 file failed to reformat.",
585             )
586             self.assertEqual(report.return_code, 123)
587             report.done(Path("f3"), black.Changed.YES)
588             self.assertEqual(len(out_lines), 4)
589             self.assertEqual(len(err_lines), 1)
590             self.assertEqual(out_lines[-1], "reformatted f3")
591             self.assertEqual(
592                 unstyle(str(report)),
593                 "2 files reformatted, 2 files left unchanged, "
594                 "1 file failed to reformat.",
595             )
596             self.assertEqual(report.return_code, 123)
597             report.failed(Path("e2"), "boom")
598             self.assertEqual(len(out_lines), 4)
599             self.assertEqual(len(err_lines), 2)
600             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
601             self.assertEqual(
602                 unstyle(str(report)),
603                 "2 files reformatted, 2 files left unchanged, "
604                 "2 files failed to reformat.",
605             )
606             self.assertEqual(report.return_code, 123)
607             report.path_ignored(Path("wat"), "no match")
608             self.assertEqual(len(out_lines), 5)
609             self.assertEqual(len(err_lines), 2)
610             self.assertEqual(out_lines[-1], "wat ignored: no match")
611             self.assertEqual(
612                 unstyle(str(report)),
613                 "2 files reformatted, 2 files left unchanged, "
614                 "2 files failed to reformat.",
615             )
616             self.assertEqual(report.return_code, 123)
617             report.done(Path("f4"), black.Changed.NO)
618             self.assertEqual(len(out_lines), 6)
619             self.assertEqual(len(err_lines), 2)
620             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
621             self.assertEqual(
622                 unstyle(str(report)),
623                 "2 files reformatted, 3 files left unchanged, "
624                 "2 files failed to reformat.",
625             )
626             self.assertEqual(report.return_code, 123)
627             report.check = True
628             self.assertEqual(
629                 unstyle(str(report)),
630                 "2 files would be reformatted, 3 files would be left unchanged, "
631                 "2 files would fail to reformat.",
632             )
633
634     def test_report_quiet(self) -> None:
635         report = black.Report(quiet=True)
636         out_lines = []
637         err_lines = []
638
639         def out(msg: str, **kwargs: Any) -> None:
640             out_lines.append(msg)
641
642         def err(msg: str, **kwargs: Any) -> None:
643             err_lines.append(msg)
644
645         with patch("black.out", out), patch("black.err", err):
646             report.done(Path("f1"), black.Changed.NO)
647             self.assertEqual(len(out_lines), 0)
648             self.assertEqual(len(err_lines), 0)
649             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
650             self.assertEqual(report.return_code, 0)
651             report.done(Path("f2"), black.Changed.YES)
652             self.assertEqual(len(out_lines), 0)
653             self.assertEqual(len(err_lines), 0)
654             self.assertEqual(
655                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
656             )
657             report.done(Path("f3"), black.Changed.CACHED)
658             self.assertEqual(len(out_lines), 0)
659             self.assertEqual(len(err_lines), 0)
660             self.assertEqual(
661                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
662             )
663             self.assertEqual(report.return_code, 0)
664             report.check = True
665             self.assertEqual(report.return_code, 1)
666             report.check = False
667             report.failed(Path("e1"), "boom")
668             self.assertEqual(len(out_lines), 0)
669             self.assertEqual(len(err_lines), 1)
670             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
671             self.assertEqual(
672                 unstyle(str(report)),
673                 "1 file reformatted, 2 files left unchanged, "
674                 "1 file failed to reformat.",
675             )
676             self.assertEqual(report.return_code, 123)
677             report.done(Path("f3"), black.Changed.YES)
678             self.assertEqual(len(out_lines), 0)
679             self.assertEqual(len(err_lines), 1)
680             self.assertEqual(
681                 unstyle(str(report)),
682                 "2 files reformatted, 2 files left unchanged, "
683                 "1 file failed to reformat.",
684             )
685             self.assertEqual(report.return_code, 123)
686             report.failed(Path("e2"), "boom")
687             self.assertEqual(len(out_lines), 0)
688             self.assertEqual(len(err_lines), 2)
689             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
690             self.assertEqual(
691                 unstyle(str(report)),
692                 "2 files reformatted, 2 files left unchanged, "
693                 "2 files failed to reformat.",
694             )
695             self.assertEqual(report.return_code, 123)
696             report.path_ignored(Path("wat"), "no match")
697             self.assertEqual(len(out_lines), 0)
698             self.assertEqual(len(err_lines), 2)
699             self.assertEqual(
700                 unstyle(str(report)),
701                 "2 files reformatted, 2 files left unchanged, "
702                 "2 files failed to reformat.",
703             )
704             self.assertEqual(report.return_code, 123)
705             report.done(Path("f4"), black.Changed.NO)
706             self.assertEqual(len(out_lines), 0)
707             self.assertEqual(len(err_lines), 2)
708             self.assertEqual(
709                 unstyle(str(report)),
710                 "2 files reformatted, 3 files left unchanged, "
711                 "2 files failed to reformat.",
712             )
713             self.assertEqual(report.return_code, 123)
714             report.check = True
715             self.assertEqual(
716                 unstyle(str(report)),
717                 "2 files would be reformatted, 3 files would be left unchanged, "
718                 "2 files would fail to reformat.",
719             )
720
721     def test_report_normal(self) -> None:
722         report = black.Report()
723         out_lines = []
724         err_lines = []
725
726         def out(msg: str, **kwargs: Any) -> None:
727             out_lines.append(msg)
728
729         def err(msg: str, **kwargs: Any) -> None:
730             err_lines.append(msg)
731
732         with patch("black.out", out), patch("black.err", err):
733             report.done(Path("f1"), black.Changed.NO)
734             self.assertEqual(len(out_lines), 0)
735             self.assertEqual(len(err_lines), 0)
736             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
737             self.assertEqual(report.return_code, 0)
738             report.done(Path("f2"), black.Changed.YES)
739             self.assertEqual(len(out_lines), 1)
740             self.assertEqual(len(err_lines), 0)
741             self.assertEqual(out_lines[-1], "reformatted f2")
742             self.assertEqual(
743                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
744             )
745             report.done(Path("f3"), black.Changed.CACHED)
746             self.assertEqual(len(out_lines), 1)
747             self.assertEqual(len(err_lines), 0)
748             self.assertEqual(out_lines[-1], "reformatted f2")
749             self.assertEqual(
750                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
751             )
752             self.assertEqual(report.return_code, 0)
753             report.check = True
754             self.assertEqual(report.return_code, 1)
755             report.check = False
756             report.failed(Path("e1"), "boom")
757             self.assertEqual(len(out_lines), 1)
758             self.assertEqual(len(err_lines), 1)
759             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
760             self.assertEqual(
761                 unstyle(str(report)),
762                 "1 file reformatted, 2 files left unchanged, "
763                 "1 file failed to reformat.",
764             )
765             self.assertEqual(report.return_code, 123)
766             report.done(Path("f3"), black.Changed.YES)
767             self.assertEqual(len(out_lines), 2)
768             self.assertEqual(len(err_lines), 1)
769             self.assertEqual(out_lines[-1], "reformatted f3")
770             self.assertEqual(
771                 unstyle(str(report)),
772                 "2 files reformatted, 2 files left unchanged, "
773                 "1 file failed to reformat.",
774             )
775             self.assertEqual(report.return_code, 123)
776             report.failed(Path("e2"), "boom")
777             self.assertEqual(len(out_lines), 2)
778             self.assertEqual(len(err_lines), 2)
779             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
780             self.assertEqual(
781                 unstyle(str(report)),
782                 "2 files reformatted, 2 files left unchanged, "
783                 "2 files failed to reformat.",
784             )
785             self.assertEqual(report.return_code, 123)
786             report.path_ignored(Path("wat"), "no match")
787             self.assertEqual(len(out_lines), 2)
788             self.assertEqual(len(err_lines), 2)
789             self.assertEqual(
790                 unstyle(str(report)),
791                 "2 files reformatted, 2 files left unchanged, "
792                 "2 files failed to reformat.",
793             )
794             self.assertEqual(report.return_code, 123)
795             report.done(Path("f4"), black.Changed.NO)
796             self.assertEqual(len(out_lines), 2)
797             self.assertEqual(len(err_lines), 2)
798             self.assertEqual(
799                 unstyle(str(report)),
800                 "2 files reformatted, 3 files left unchanged, "
801                 "2 files failed to reformat.",
802             )
803             self.assertEqual(report.return_code, 123)
804             report.check = True
805             self.assertEqual(
806                 unstyle(str(report)),
807                 "2 files would be reformatted, 3 files would be left unchanged, "
808                 "2 files would fail to reformat.",
809             )
810
811     def test_get_features_used(self) -> None:
812         node = black.lib2to3_parse("def f(*, arg): ...\n")
813         self.assertEqual(black.get_features_used(node), set())
814         node = black.lib2to3_parse("def f(*, arg,): ...\n")
815         self.assertEqual(black.get_features_used(node), {Feature.TRAILING_COMMA})
816         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
817         self.assertEqual(black.get_features_used(node), {Feature.F_STRINGS})
818         node = black.lib2to3_parse("123_456\n")
819         self.assertEqual(black.get_features_used(node), {Feature.NUMERIC_UNDERSCORES})
820         node = black.lib2to3_parse("123456\n")
821         self.assertEqual(black.get_features_used(node), set())
822         source, expected = read_data("function")
823         node = black.lib2to3_parse(source)
824         self.assertEqual(
825             black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
826         )
827         node = black.lib2to3_parse(expected)
828         self.assertEqual(
829             black.get_features_used(node), {Feature.TRAILING_COMMA, Feature.F_STRINGS}
830         )
831         source, expected = read_data("expression")
832         node = black.lib2to3_parse(source)
833         self.assertEqual(black.get_features_used(node), set())
834         node = black.lib2to3_parse(expected)
835         self.assertEqual(black.get_features_used(node), set())
836
837     def test_get_future_imports(self) -> None:
838         node = black.lib2to3_parse("\n")
839         self.assertEqual(set(), black.get_future_imports(node))
840         node = black.lib2to3_parse("from __future__ import black\n")
841         self.assertEqual({"black"}, black.get_future_imports(node))
842         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
843         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
844         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
845         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
846         node = black.lib2to3_parse(
847             "from __future__ import multiple\nfrom __future__ import imports\n"
848         )
849         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
850         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
851         self.assertEqual({"black"}, black.get_future_imports(node))
852         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
853         self.assertEqual({"black"}, black.get_future_imports(node))
854         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
855         self.assertEqual(set(), black.get_future_imports(node))
856         node = black.lib2to3_parse("from some.module import black\n")
857         self.assertEqual(set(), black.get_future_imports(node))
858         node = black.lib2to3_parse(
859             "from __future__ import unicode_literals as _unicode_literals"
860         )
861         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
862         node = black.lib2to3_parse(
863             "from __future__ import unicode_literals as _lol, print"
864         )
865         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
866
867     def test_debug_visitor(self) -> None:
868         source, _ = read_data("debug_visitor.py")
869         expected, _ = read_data("debug_visitor.out")
870         out_lines = []
871         err_lines = []
872
873         def out(msg: str, **kwargs: Any) -> None:
874             out_lines.append(msg)
875
876         def err(msg: str, **kwargs: Any) -> None:
877             err_lines.append(msg)
878
879         with patch("black.out", out), patch("black.err", err):
880             black.DebugVisitor.show(source)
881         actual = "\n".join(out_lines) + "\n"
882         log_name = ""
883         if expected != actual:
884             log_name = black.dump_to_file(*out_lines)
885         self.assertEqual(
886             expected,
887             actual,
888             f"AST print out is different. Actual version dumped to {log_name}",
889         )
890
891     def test_format_file_contents(self) -> None:
892         empty = ""
893         mode = black.FileMode()
894         with self.assertRaises(black.NothingChanged):
895             black.format_file_contents(empty, mode=mode, fast=False)
896         just_nl = "\n"
897         with self.assertRaises(black.NothingChanged):
898             black.format_file_contents(just_nl, mode=mode, fast=False)
899         same = "l = [1, 2, 3]\n"
900         with self.assertRaises(black.NothingChanged):
901             black.format_file_contents(same, mode=mode, fast=False)
902         different = "l = [1,2,3]"
903         expected = same
904         actual = black.format_file_contents(different, mode=mode, fast=False)
905         self.assertEqual(expected, actual)
906         invalid = "return if you can"
907         with self.assertRaises(black.InvalidInput) as e:
908             black.format_file_contents(invalid, mode=mode, fast=False)
909         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
910
911     def test_endmarker(self) -> None:
912         n = black.lib2to3_parse("\n")
913         self.assertEqual(n.type, black.syms.file_input)
914         self.assertEqual(len(n.children), 1)
915         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
916
917     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
918     def test_assertFormatEqual(self) -> None:
919         out_lines = []
920         err_lines = []
921
922         def out(msg: str, **kwargs: Any) -> None:
923             out_lines.append(msg)
924
925         def err(msg: str, **kwargs: Any) -> None:
926             err_lines.append(msg)
927
928         with patch("black.out", out), patch("black.err", err):
929             with self.assertRaises(AssertionError):
930                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
931
932         out_str = "".join(out_lines)
933         self.assertTrue("Expected tree:" in out_str)
934         self.assertTrue("Actual tree:" in out_str)
935         self.assertEqual("".join(err_lines), "")
936
937     def test_cache_broken_file(self) -> None:
938         mode = black.FileMode()
939         with cache_dir() as workspace:
940             cache_file = black.get_cache_file(mode)
941             with cache_file.open("w") as fobj:
942                 fobj.write("this is not a pickle")
943             self.assertEqual(black.read_cache(mode), {})
944             src = (workspace / "test.py").resolve()
945             with src.open("w") as fobj:
946                 fobj.write("print('hello')")
947             self.invokeBlack([str(src)])
948             cache = black.read_cache(mode)
949             self.assertIn(src, cache)
950
951     def test_cache_single_file_already_cached(self) -> None:
952         mode = black.FileMode()
953         with cache_dir() as workspace:
954             src = (workspace / "test.py").resolve()
955             with src.open("w") as fobj:
956                 fobj.write("print('hello')")
957             black.write_cache({}, [src], mode)
958             self.invokeBlack([str(src)])
959             with src.open("r") as fobj:
960                 self.assertEqual(fobj.read(), "print('hello')")
961
962     @event_loop(close=False)
963     def test_cache_multiple_files(self) -> None:
964         mode = black.FileMode()
965         with cache_dir() as workspace, patch(
966             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
967         ):
968             one = (workspace / "one.py").resolve()
969             with one.open("w") as fobj:
970                 fobj.write("print('hello')")
971             two = (workspace / "two.py").resolve()
972             with two.open("w") as fobj:
973                 fobj.write("print('hello')")
974             black.write_cache({}, [one], mode)
975             self.invokeBlack([str(workspace)])
976             with one.open("r") as fobj:
977                 self.assertEqual(fobj.read(), "print('hello')")
978             with two.open("r") as fobj:
979                 self.assertEqual(fobj.read(), 'print("hello")\n')
980             cache = black.read_cache(mode)
981             self.assertIn(one, cache)
982             self.assertIn(two, cache)
983
984     def test_no_cache_when_writeback_diff(self) -> None:
985         mode = black.FileMode()
986         with cache_dir() as workspace:
987             src = (workspace / "test.py").resolve()
988             with src.open("w") as fobj:
989                 fobj.write("print('hello')")
990             self.invokeBlack([str(src), "--diff"])
991             cache_file = black.get_cache_file(mode)
992             self.assertFalse(cache_file.exists())
993
994     def test_no_cache_when_stdin(self) -> None:
995         mode = black.FileMode()
996         with cache_dir():
997             result = CliRunner().invoke(
998                 black.main, ["-"], input=BytesIO(b"print('hello')")
999             )
1000             self.assertEqual(result.exit_code, 0)
1001             cache_file = black.get_cache_file(mode)
1002             self.assertFalse(cache_file.exists())
1003
1004     def test_read_cache_no_cachefile(self) -> None:
1005         mode = black.FileMode()
1006         with cache_dir():
1007             self.assertEqual(black.read_cache(mode), {})
1008
1009     def test_write_cache_read_cache(self) -> None:
1010         mode = black.FileMode()
1011         with cache_dir() as workspace:
1012             src = (workspace / "test.py").resolve()
1013             src.touch()
1014             black.write_cache({}, [src], mode)
1015             cache = black.read_cache(mode)
1016             self.assertIn(src, cache)
1017             self.assertEqual(cache[src], black.get_cache_info(src))
1018
1019     def test_filter_cached(self) -> None:
1020         with TemporaryDirectory() as workspace:
1021             path = Path(workspace)
1022             uncached = (path / "uncached").resolve()
1023             cached = (path / "cached").resolve()
1024             cached_but_changed = (path / "changed").resolve()
1025             uncached.touch()
1026             cached.touch()
1027             cached_but_changed.touch()
1028             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
1029             todo, done = black.filter_cached(
1030                 cache, {uncached, cached, cached_but_changed}
1031             )
1032             self.assertEqual(todo, {uncached, cached_but_changed})
1033             self.assertEqual(done, {cached})
1034
1035     def test_write_cache_creates_directory_if_needed(self) -> None:
1036         mode = black.FileMode()
1037         with cache_dir(exists=False) as workspace:
1038             self.assertFalse(workspace.exists())
1039             black.write_cache({}, [], mode)
1040             self.assertTrue(workspace.exists())
1041
1042     @event_loop(close=False)
1043     def test_failed_formatting_does_not_get_cached(self) -> None:
1044         mode = black.FileMode()
1045         with cache_dir() as workspace, patch(
1046             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1047         ):
1048             failing = (workspace / "failing.py").resolve()
1049             with failing.open("w") as fobj:
1050                 fobj.write("not actually python")
1051             clean = (workspace / "clean.py").resolve()
1052             with clean.open("w") as fobj:
1053                 fobj.write('print("hello")\n')
1054             self.invokeBlack([str(workspace)], exit_code=123)
1055             cache = black.read_cache(mode)
1056             self.assertNotIn(failing, cache)
1057             self.assertIn(clean, cache)
1058
1059     def test_write_cache_write_fail(self) -> None:
1060         mode = black.FileMode()
1061         with cache_dir(), patch.object(Path, "open") as mock:
1062             mock.side_effect = OSError
1063             black.write_cache({}, [], mode)
1064
1065     @event_loop(close=False)
1066     def test_check_diff_use_together(self) -> None:
1067         with cache_dir():
1068             # Files which will be reformatted.
1069             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1070             self.invokeBlack([str(src1), "--diff", "--check"], exit_code=1)
1071             # Files which will not be reformatted.
1072             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1073             self.invokeBlack([str(src2), "--diff", "--check"])
1074             # Multi file command.
1075             self.invokeBlack([str(src1), str(src2), "--diff", "--check"], exit_code=1)
1076
1077     def test_no_files(self) -> None:
1078         with cache_dir():
1079             # Without an argument, black exits with error code 0.
1080             self.invokeBlack([])
1081
1082     def test_broken_symlink(self) -> None:
1083         with cache_dir() as workspace:
1084             symlink = workspace / "broken_link.py"
1085             try:
1086                 symlink.symlink_to("nonexistent.py")
1087             except OSError as e:
1088                 self.skipTest(f"Can't create symlinks: {e}")
1089             self.invokeBlack([str(workspace.resolve())])
1090
1091     def test_read_cache_line_lengths(self) -> None:
1092         mode = black.FileMode()
1093         short_mode = black.FileMode(line_length=1)
1094         with cache_dir() as workspace:
1095             path = (workspace / "file.py").resolve()
1096             path.touch()
1097             black.write_cache({}, [path], mode)
1098             one = black.read_cache(mode)
1099             self.assertIn(path, one)
1100             two = black.read_cache(short_mode)
1101             self.assertNotIn(path, two)
1102
1103     def test_single_file_force_pyi(self) -> None:
1104         reg_mode = black.FileMode()
1105         pyi_mode = black.FileMode(is_pyi=True)
1106         contents, expected = read_data("force_pyi")
1107         with cache_dir() as workspace:
1108             path = (workspace / "file.py").resolve()
1109             with open(path, "w") as fh:
1110                 fh.write(contents)
1111             self.invokeBlack([str(path), "--pyi"])
1112             with open(path, "r") as fh:
1113                 actual = fh.read()
1114             # verify cache with --pyi is separate
1115             pyi_cache = black.read_cache(pyi_mode)
1116             self.assertIn(path, pyi_cache)
1117             normal_cache = black.read_cache(reg_mode)
1118             self.assertNotIn(path, normal_cache)
1119         self.assertEqual(actual, expected)
1120
1121     @event_loop(close=False)
1122     def test_multi_file_force_pyi(self) -> None:
1123         reg_mode = black.FileMode()
1124         pyi_mode = black.FileMode(is_pyi=True)
1125         contents, expected = read_data("force_pyi")
1126         with cache_dir() as workspace:
1127             paths = [
1128                 (workspace / "file1.py").resolve(),
1129                 (workspace / "file2.py").resolve(),
1130             ]
1131             for path in paths:
1132                 with open(path, "w") as fh:
1133                     fh.write(contents)
1134             self.invokeBlack([str(p) for p in paths] + ["--pyi"])
1135             for path in paths:
1136                 with open(path, "r") as fh:
1137                     actual = fh.read()
1138                 self.assertEqual(actual, expected)
1139             # verify cache with --pyi is separate
1140             pyi_cache = black.read_cache(pyi_mode)
1141             normal_cache = black.read_cache(reg_mode)
1142             for path in paths:
1143                 self.assertIn(path, pyi_cache)
1144                 self.assertNotIn(path, normal_cache)
1145
1146     def test_pipe_force_pyi(self) -> None:
1147         source, expected = read_data("force_pyi")
1148         result = CliRunner().invoke(
1149             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1150         )
1151         self.assertEqual(result.exit_code, 0)
1152         actual = result.output
1153         self.assertFormatEqual(actual, expected)
1154
1155     def test_single_file_force_py36(self) -> None:
1156         reg_mode = black.FileMode()
1157         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1158         source, expected = read_data("force_py36")
1159         with cache_dir() as workspace:
1160             path = (workspace / "file.py").resolve()
1161             with open(path, "w") as fh:
1162                 fh.write(source)
1163             self.invokeBlack([str(path), "--py36"])
1164             with open(path, "r") as fh:
1165                 actual = fh.read()
1166             # verify cache with --py36 is separate
1167             py36_cache = black.read_cache(py36_mode)
1168             self.assertIn(path, py36_cache)
1169             normal_cache = black.read_cache(reg_mode)
1170             self.assertNotIn(path, normal_cache)
1171         self.assertEqual(actual, expected)
1172
1173     @event_loop(close=False)
1174     def test_multi_file_force_py36(self) -> None:
1175         reg_mode = black.FileMode()
1176         py36_mode = black.FileMode(target_versions=black.PY36_VERSIONS)
1177         source, expected = read_data("force_py36")
1178         with cache_dir() as workspace:
1179             paths = [
1180                 (workspace / "file1.py").resolve(),
1181                 (workspace / "file2.py").resolve(),
1182             ]
1183             for path in paths:
1184                 with open(path, "w") as fh:
1185                     fh.write(source)
1186             self.invokeBlack([str(p) for p in paths] + ["--py36"])
1187             for path in paths:
1188                 with open(path, "r") as fh:
1189                     actual = fh.read()
1190                 self.assertEqual(actual, expected)
1191             # verify cache with --py36 is separate
1192             pyi_cache = black.read_cache(py36_mode)
1193             normal_cache = black.read_cache(reg_mode)
1194             for path in paths:
1195                 self.assertIn(path, pyi_cache)
1196                 self.assertNotIn(path, normal_cache)
1197
1198     def test_pipe_force_py36(self) -> None:
1199         source, expected = read_data("force_py36")
1200         result = CliRunner().invoke(
1201             black.main, ["-", "-q", "--py36"], input=BytesIO(source.encode("utf8"))
1202         )
1203         self.assertEqual(result.exit_code, 0)
1204         actual = result.output
1205         self.assertFormatEqual(actual, expected)
1206
1207     def test_include_exclude(self) -> None:
1208         path = THIS_DIR / "data" / "include_exclude_tests"
1209         include = re.compile(r"\.pyi?$")
1210         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1211         report = black.Report()
1212         sources: List[Path] = []
1213         expected = [
1214             Path(path / "b/dont_exclude/a.py"),
1215             Path(path / "b/dont_exclude/a.pyi"),
1216         ]
1217         this_abs = THIS_DIR.resolve()
1218         sources.extend(
1219             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1220         )
1221         self.assertEqual(sorted(expected), sorted(sources))
1222
1223     def test_empty_include(self) -> None:
1224         path = THIS_DIR / "data" / "include_exclude_tests"
1225         report = black.Report()
1226         empty = re.compile(r"")
1227         sources: List[Path] = []
1228         expected = [
1229             Path(path / "b/exclude/a.pie"),
1230             Path(path / "b/exclude/a.py"),
1231             Path(path / "b/exclude/a.pyi"),
1232             Path(path / "b/dont_exclude/a.pie"),
1233             Path(path / "b/dont_exclude/a.py"),
1234             Path(path / "b/dont_exclude/a.pyi"),
1235             Path(path / "b/.definitely_exclude/a.pie"),
1236             Path(path / "b/.definitely_exclude/a.py"),
1237             Path(path / "b/.definitely_exclude/a.pyi"),
1238         ]
1239         this_abs = THIS_DIR.resolve()
1240         sources.extend(
1241             black.gen_python_files_in_dir(
1242                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1243             )
1244         )
1245         self.assertEqual(sorted(expected), sorted(sources))
1246
1247     def test_empty_exclude(self) -> None:
1248         path = THIS_DIR / "data" / "include_exclude_tests"
1249         report = black.Report()
1250         empty = re.compile(r"")
1251         sources: List[Path] = []
1252         expected = [
1253             Path(path / "b/dont_exclude/a.py"),
1254             Path(path / "b/dont_exclude/a.pyi"),
1255             Path(path / "b/exclude/a.py"),
1256             Path(path / "b/exclude/a.pyi"),
1257             Path(path / "b/.definitely_exclude/a.py"),
1258             Path(path / "b/.definitely_exclude/a.pyi"),
1259         ]
1260         this_abs = THIS_DIR.resolve()
1261         sources.extend(
1262             black.gen_python_files_in_dir(
1263                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1264             )
1265         )
1266         self.assertEqual(sorted(expected), sorted(sources))
1267
1268     def test_invalid_include_exclude(self) -> None:
1269         for option in ["--include", "--exclude"]:
1270             self.invokeBlack(["-", option, "**()(!!*)"], exit_code=2)
1271
1272     def test_preserves_line_endings(self) -> None:
1273         with TemporaryDirectory() as workspace:
1274             test_file = Path(workspace) / "test.py"
1275             for nl in ["\n", "\r\n"]:
1276                 contents = nl.join(["def f(  ):", "    pass"])
1277                 test_file.write_bytes(contents.encode())
1278                 ff(test_file, write_back=black.WriteBack.YES)
1279                 updated_contents: bytes = test_file.read_bytes()
1280                 self.assertIn(nl.encode(), updated_contents)
1281                 if nl == "\n":
1282                     self.assertNotIn(b"\r\n", updated_contents)
1283
1284     def test_preserves_line_endings_via_stdin(self) -> None:
1285         for nl in ["\n", "\r\n"]:
1286             contents = nl.join(["def f(  ):", "    pass"])
1287             runner = BlackRunner()
1288             result = runner.invoke(
1289                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1290             )
1291             self.assertEqual(result.exit_code, 0)
1292             output = runner.stdout_bytes
1293             self.assertIn(nl.encode("utf8"), output)
1294             if nl == "\n":
1295                 self.assertNotIn(b"\r\n", output)
1296
1297     def test_assert_equivalent_different_asts(self) -> None:
1298         with self.assertRaises(AssertionError):
1299             black.assert_equivalent("{}", "None")
1300
1301     def test_symlink_out_of_root_directory(self) -> None:
1302         path = MagicMock()
1303         root = THIS_DIR
1304         child = MagicMock()
1305         include = re.compile(black.DEFAULT_INCLUDES)
1306         exclude = re.compile(black.DEFAULT_EXCLUDES)
1307         report = black.Report()
1308         # `child` should behave like a symlink which resolved path is clearly
1309         # outside of the `root` directory.
1310         path.iterdir.return_value = [child]
1311         child.resolve.return_value = Path("/a/b/c")
1312         child.is_symlink.return_value = True
1313         try:
1314             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1315         except ValueError as ve:
1316             self.fail(f"`get_python_files_in_dir()` failed: {ve}")
1317         path.iterdir.assert_called_once()
1318         child.resolve.assert_called_once()
1319         child.is_symlink.assert_called_once()
1320         # `child` should behave like a strange file which resolved path is clearly
1321         # outside of the `root` directory.
1322         child.is_symlink.return_value = False
1323         with self.assertRaises(ValueError):
1324             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1325         path.iterdir.assert_called()
1326         self.assertEqual(path.iterdir.call_count, 2)
1327         child.resolve.assert_called()
1328         self.assertEqual(child.resolve.call_count, 2)
1329         child.is_symlink.assert_called()
1330         self.assertEqual(child.is_symlink.call_count, 2)
1331
1332     def test_shhh_click(self) -> None:
1333         try:
1334             from click import _unicodefun  # type: ignore
1335         except ModuleNotFoundError:
1336             self.skipTest("Incompatible Click version")
1337         if not hasattr(_unicodefun, "_verify_python3_env"):
1338             self.skipTest("Incompatible Click version")
1339         # First, let's see if Click is crashing with a preferred ASCII charset.
1340         with patch("locale.getpreferredencoding") as gpe:
1341             gpe.return_value = "ASCII"
1342             with self.assertRaises(RuntimeError):
1343                 _unicodefun._verify_python3_env()
1344         # Now, let's silence Click...
1345         black.patch_click()
1346         # ...and confirm it's silent.
1347         with patch("locale.getpreferredencoding") as gpe:
1348             gpe.return_value = "ASCII"
1349             try:
1350                 _unicodefun._verify_python3_env()
1351             except RuntimeError as re:
1352                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1353
1354     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1355     @async_test
1356     async def test_blackd_request_needs_formatting(self) -> None:
1357         app = blackd.make_app()
1358         async with TestClient(TestServer(app)) as client:
1359             response = await client.post("/", data=b"print('hello world')")
1360             self.assertEqual(response.status, 200)
1361             self.assertEqual(response.charset, "utf8")
1362             self.assertEqual(await response.read(), b'print("hello world")\n')
1363
1364     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1365     @async_test
1366     async def test_blackd_request_no_change(self) -> None:
1367         app = blackd.make_app()
1368         async with TestClient(TestServer(app)) as client:
1369             response = await client.post("/", data=b'print("hello world")\n')
1370             self.assertEqual(response.status, 204)
1371             self.assertEqual(await response.read(), b"")
1372
1373     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1374     @async_test
1375     async def test_blackd_request_syntax_error(self) -> None:
1376         app = blackd.make_app()
1377         async with TestClient(TestServer(app)) as client:
1378             response = await client.post("/", data=b"what even ( is")
1379             self.assertEqual(response.status, 400)
1380             content = await response.text()
1381             self.assertTrue(
1382                 content.startswith("Cannot parse"),
1383                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1384             )
1385
1386     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1387     @async_test
1388     async def test_blackd_unsupported_version(self) -> None:
1389         app = blackd.make_app()
1390         async with TestClient(TestServer(app)) as client:
1391             response = await client.post(
1392                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1393             )
1394             self.assertEqual(response.status, 501)
1395
1396     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1397     @async_test
1398     async def test_blackd_supported_version(self) -> None:
1399         app = blackd.make_app()
1400         async with TestClient(TestServer(app)) as client:
1401             response = await client.post(
1402                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1403             )
1404             self.assertEqual(response.status, 200)
1405
1406     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1407     @async_test
1408     async def test_blackd_invalid_python_variant(self) -> None:
1409         app = blackd.make_app()
1410         async with TestClient(TestServer(app)) as client:
1411
1412             async def check(header_value: str, expected_status: int = 400) -> None:
1413                 response = await client.post(
1414                     "/",
1415                     data=b"what",
1416                     headers={blackd.PYTHON_VARIANT_HEADER: header_value},
1417                 )
1418                 self.assertEqual(response.status, expected_status)
1419
1420             await check("lol")
1421             await check("ruby3.5")
1422             await check("pyi3.6")
1423             await check("cpy1.5")
1424             await check("2.8")
1425             await check("cpy2.8")
1426             await check("3.0")
1427             await check("pypy3.0")
1428             await check("jython3.4")
1429
1430     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1431     @async_test
1432     async def test_blackd_pyi(self) -> None:
1433         app = blackd.make_app()
1434         async with TestClient(TestServer(app)) as client:
1435             source, expected = read_data("stub.pyi")
1436             response = await client.post(
1437                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1438             )
1439             self.assertEqual(response.status, 200)
1440             self.assertEqual(await response.text(), expected)
1441
1442     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1443     @async_test
1444     async def test_blackd_python_variant(self) -> None:
1445         app = blackd.make_app()
1446         code = (
1447             "def f(\n"
1448             "    and_has_a_bunch_of,\n"
1449             "    very_long_arguments_too,\n"
1450             "    and_lots_of_them_as_well_lol,\n"
1451             "    **and_very_long_keyword_arguments\n"
1452             "):\n"
1453             "    pass\n"
1454         )
1455         async with TestClient(TestServer(app)) as client:
1456
1457             async def check(header_value: str, expected_status: int) -> None:
1458                 response = await client.post(
1459                     "/", data=code, headers={blackd.PYTHON_VARIANT_HEADER: header_value}
1460                 )
1461                 self.assertEqual(response.status, expected_status)
1462
1463             await check("3.6", 200)
1464             await check("cpy3.6", 200)
1465             await check("3.5,3.7", 200)
1466             await check("3.5,cpy3.7", 200)
1467
1468             await check("2", 204)
1469             await check("2.7", 204)
1470             await check("cpy2.7", 204)
1471             await check("pypy2.7", 204)
1472             await check("3.4", 204)
1473             await check("cpy3.4", 204)
1474             await check("pypy3.4", 204)
1475
1476     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1477     @async_test
1478     async def test_blackd_fast(self) -> None:
1479         with open(os.devnull, "w") as dn, redirect_stderr(dn):
1480             app = blackd.make_app()
1481             async with TestClient(TestServer(app)) as client:
1482                 response = await client.post("/", data=b"ur'hello'")
1483                 self.assertEqual(response.status, 500)
1484                 self.assertIn("failed to parse source file", await response.text())
1485                 response = await client.post(
1486                     "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1487                 )
1488                 self.assertEqual(response.status, 200)
1489
1490     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1491     @async_test
1492     async def test_blackd_line_length(self) -> None:
1493         app = blackd.make_app()
1494         async with TestClient(TestServer(app)) as client:
1495             response = await client.post(
1496                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1497             )
1498             self.assertEqual(response.status, 200)
1499
1500     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1501     @async_test
1502     async def test_blackd_invalid_line_length(self) -> None:
1503         app = blackd.make_app()
1504         async with TestClient(TestServer(app)) as client:
1505             response = await client.post(
1506                 "/",
1507                 data=b'print("hello")\n',
1508                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1509             )
1510             self.assertEqual(response.status, 400)
1511
1512     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1513     def test_blackd_main(self) -> None:
1514         with patch("blackd.web.run_app"):
1515             result = CliRunner().invoke(blackd.main, [])
1516             if result.exception is not None:
1517                 raise result.exception
1518             self.assertEqual(result.exit_code, 0)
1519
1520
if __name__ == "__main__":
    # When run as a script, hand off to unittest's command-line runner, telling
    # it to load the tests from the module named "test_black".
    unittest.main(module="test_black")