git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Update Poetry section in pyproject.toml (#490)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial, wraps
6 from io import BytesIO, TextIOWrapper
7 import os
8 from pathlib import Path
9 import re
10 import sys
11 from tempfile import TemporaryDirectory
12 from typing import (
13     Any,
14     BinaryIO,
15     Callable,
16     Coroutine,
17     Generator,
18     List,
19     Tuple,
20     Iterator,
21     TypeVar,
22 )
23 import unittest
24 from unittest.mock import patch, MagicMock
25
26 from click import unstyle
27 from click.testing import CliRunner
28
29 import black
30
31 try:
32     import blackd
33     from aiohttp.test_utils import TestClient, TestServer
34 except ImportError:
35     has_blackd_deps = False
36 else:
37     has_blackd_deps = True
38
39
# Line length passed to ff/fs below and to black.assert_stable in the tests.
ll = 88
# ff formats a file in place (fast=True); fs formats a string. Both use `ll`.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker that read_data() removes from every fixture line. It is built by
# concatenation, presumably so that this file's own source (which test_self reads
# through read_data()) never contains the complete marker.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
T = TypeVar("T")
# R is the coroutine result type accepted by async_test().
R = TypeVar("R")
48
49
def dump_to_stderr(*output: str) -> str:
    """Join `output` with newlines and wrap the result in leading/trailing newlines.

    Tests patch this over black.dump_to_file so debug dumps come back as text
    rather than as a path to a file.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
52
53
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    `name` gets a ".py" suffix unless it already ends in .py/.pyi/.out/.diff. It is
    looked up in the data/ directory next to this file, or next to this file itself
    when `data` is false. The EMPTY_LINE marker is removed from every line, and
    lines after a line reading `# output` form the expected output. If that part is
    empty, the input is expected to come back unchanged. Both parts are stripped
    and end with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else THIS_DIR
    with open(base_dir / name, "r", encoding="utf8") as fixture:
        lines = fixture.readlines()
    source_lines: List[str] = []
    expected_lines: List[str] = []
    target = source_lines
    for raw_line in lines:
        line = raw_line.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            target = expected_lines
            continue
        target.append(line)
    if source_lines and not expected_lines:
        # Without an output section the fixture is treated as already formatted.
        expected_lines = source_lines
    source = "".join(source_lines).strip() + "\n"
    expected = "".join(expected_lines).strip() + "\n"
    return source, expected
75
76
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point black.CACHE_DIR at a temporary directory while the block runs.

    With `exists=False` the patched path is a not-yet-created "new" subdirectory
    of the temporary directory. The patched path is yielded.
    """
    with TemporaryDirectory() as workspace:
        path = Path(workspace) if exists else Path(workspace) / "new"
        with patch("black.CACHE_DIR", path):
            yield path
85
86
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new event loop as the current loop for the `with` block.

    The loop that was current beforehand is reinstated on exit. The new loop is
    closed only when `close` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous_loop = policy.get_event_loop()
    fresh_loop = policy.new_event_loop()
    asyncio.set_event_loop(fresh_loop)
    try:
        yield
    finally:
        policy.set_event_loop(previous_loop)
        if close:
            fresh_loop.close()
100
101
def async_test(f: Callable[..., Coroutine[Any, None, R]]) -> Callable[..., None]:
    """Turn the coroutine function `f` into a synchronous test callable.

    Every call runs `f` to completion on a fresh event loop (via event_loop), and
    that loop is closed afterwards.
    """

    @wraps(f)
    def wrapper(*args: Any, **kwargs: Any) -> None:
        asyncio.get_event_loop().run_until_complete(f(*args, **kwargs))

    return event_loop(close=True)(wrapper)
109
110
class BlackRunner(CliRunner):
    """Modify CliRunner so that stderr is not merged with stdout.

    This is a hack that can be removed once we depend on Click 7.x"""

    def __init__(self) -> None:
        self.stderrbuf = BytesIO()
        self.stdoutbuf = BytesIO()
        self.stdout_bytes = b""
        self.stderr_bytes = b""
        super().__init__()

    @contextmanager
    def isolation(self, *args: Any, **kwargs: Any) -> Generator[BinaryIO, None, None]:
        """Run CliRunner.isolation() with sys.stderr redirected into `stderrbuf`.

        On exit the captured bytes are stored in `self.stdout_bytes` and
        `self.stderr_bytes`, and the original sys.stderr is restored.
        """
        with super().isolation(*args, **kwargs) as output:
            hold_stderr = sys.stderr
            err_stream = TextIOWrapper(self.stderrbuf, encoding=self.charset)
            sys.stderr = err_stream
            try:
                yield output
            finally:
                # TextIOWrapper buffers text; flush before reading the raw bytes so
                # nothing still pending in a wrapper is lost.
                sys.stdout.flush()
                err_stream.flush()
                self.stdout_bytes = sys.stdout.buffer.getvalue()  # type: ignore
                self.stderr_bytes = self.stderrbuf.getvalue()
                sys.stderr = hold_stderr
                # Detach so garbage-collecting the wrapper doesn't close stderrbuf.
                err_stream.detach()
134
135
136 class BlackTestCase(unittest.TestCase):
137     maxDiff = None
138
139     def assertFormatEqual(self, expected: str, actual: str) -> None:
140         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
141             bdv: black.DebugVisitor[Any]
142             black.out("Expected tree:", fg="green")
143             try:
144                 exp_node = black.lib2to3_parse(expected)
145                 bdv = black.DebugVisitor()
146                 list(bdv.visit(exp_node))
147             except Exception as ve:
148                 black.err(str(ve))
149             black.out("Actual tree:", fg="red")
150             try:
151                 exp_node = black.lib2to3_parse(actual)
152                 bdv = black.DebugVisitor()
153                 list(bdv.visit(exp_node))
154             except Exception as ve:
155                 black.err(str(ve))
156         self.assertEqual(expected, actual)
157
158     @patch("black.dump_to_file", dump_to_stderr)
159     def test_empty(self) -> None:
160         source = expected = ""
161         actual = fs(source)
162         self.assertFormatEqual(expected, actual)
163         black.assert_equivalent(source, actual)
164         black.assert_stable(source, actual, line_length=ll)
165
166     def test_empty_ff(self) -> None:
167         expected = ""
168         tmp_file = Path(black.dump_to_file())
169         try:
170             self.assertFalse(ff(tmp_file, write_back=black.WriteBack.YES))
171             with open(tmp_file, encoding="utf8") as f:
172                 actual = f.read()
173         finally:
174             os.unlink(tmp_file)
175         self.assertFormatEqual(expected, actual)
176
177     @patch("black.dump_to_file", dump_to_stderr)
178     def test_self(self) -> None:
179         source, expected = read_data("test_black", data=False)
180         actual = fs(source)
181         self.assertFormatEqual(expected, actual)
182         black.assert_equivalent(source, actual)
183         black.assert_stable(source, actual, line_length=ll)
184         self.assertFalse(ff(THIS_FILE))
185
186     @patch("black.dump_to_file", dump_to_stderr)
187     def test_black(self) -> None:
188         source, expected = read_data("../black", data=False)
189         actual = fs(source)
190         self.assertFormatEqual(expected, actual)
191         black.assert_equivalent(source, actual)
192         black.assert_stable(source, actual, line_length=ll)
193         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
194
195     def test_piping(self) -> None:
196         source, expected = read_data("../black", data=False)
197         result = BlackRunner().invoke(
198             black.main,
199             ["-", "--fast", f"--line-length={ll}"],
200             input=BytesIO(source.encode("utf8")),
201         )
202         self.assertEqual(result.exit_code, 0)
203         self.assertFormatEqual(expected, result.output)
204         black.assert_equivalent(source, result.output)
205         black.assert_stable(source, result.output, line_length=ll)
206
207     def test_piping_diff(self) -> None:
208         diff_header = re.compile(
209             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
210             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
211         )
212         source, _ = read_data("expression.py")
213         expected, _ = read_data("expression.diff")
214         config = THIS_DIR / "data" / "empty_pyproject.toml"
215         args = ["-", "--fast", f"--line-length={ll}", "--diff", f"--config={config}"]
216         result = BlackRunner().invoke(
217             black.main, args, input=BytesIO(source.encode("utf8"))
218         )
219         self.assertEqual(result.exit_code, 0)
220         actual = diff_header.sub("[Deterministic header]", result.output)
221         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
222         self.assertEqual(expected, actual)
223
224     @patch("black.dump_to_file", dump_to_stderr)
225     def test_setup(self) -> None:
226         source, expected = read_data("../setup", data=False)
227         actual = fs(source)
228         self.assertFormatEqual(expected, actual)
229         black.assert_equivalent(source, actual)
230         black.assert_stable(source, actual, line_length=ll)
231         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
232
233     @patch("black.dump_to_file", dump_to_stderr)
234     def test_function(self) -> None:
235         source, expected = read_data("function")
236         actual = fs(source)
237         self.assertFormatEqual(expected, actual)
238         black.assert_equivalent(source, actual)
239         black.assert_stable(source, actual, line_length=ll)
240
241     @patch("black.dump_to_file", dump_to_stderr)
242     def test_function2(self) -> None:
243         source, expected = read_data("function2")
244         actual = fs(source)
245         self.assertFormatEqual(expected, actual)
246         black.assert_equivalent(source, actual)
247         black.assert_stable(source, actual, line_length=ll)
248
249     @patch("black.dump_to_file", dump_to_stderr)
250     def test_expression(self) -> None:
251         source, expected = read_data("expression")
252         actual = fs(source)
253         self.assertFormatEqual(expected, actual)
254         black.assert_equivalent(source, actual)
255         black.assert_stable(source, actual, line_length=ll)
256
257     def test_expression_ff(self) -> None:
258         source, expected = read_data("expression")
259         tmp_file = Path(black.dump_to_file(source))
260         try:
261             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
262             with open(tmp_file, encoding="utf8") as f:
263                 actual = f.read()
264         finally:
265             os.unlink(tmp_file)
266         self.assertFormatEqual(expected, actual)
267         with patch("black.dump_to_file", dump_to_stderr):
268             black.assert_equivalent(source, actual)
269             black.assert_stable(source, actual, line_length=ll)
270
271     def test_expression_diff(self) -> None:
272         source, _ = read_data("expression.py")
273         expected, _ = read_data("expression.diff")
274         tmp_file = Path(black.dump_to_file(source))
275         diff_header = re.compile(
276             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
277             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
278         )
279         try:
280             result = BlackRunner().invoke(black.main, ["--diff", str(tmp_file)])
281             self.assertEqual(result.exit_code, 0)
282         finally:
283             os.unlink(tmp_file)
284         actual = result.output
285         actual = diff_header.sub("[Deterministic header]", actual)
286         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
287         if expected != actual:
288             dump = black.dump_to_file(actual)
289             msg = (
290                 f"Expected diff isn't equal to the actual. If you made changes "
291                 f"to expression.py and this is an anticipated difference, "
292                 f"overwrite tests/expression.diff with {dump}"
293             )
294             self.assertEqual(expected, actual, msg)
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_fstring(self) -> None:
298         source, expected = read_data("fstring")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_equivalent(source, actual)
302         black.assert_stable(source, actual, line_length=ll)
303
304     @patch("black.dump_to_file", dump_to_stderr)
305     def test_string_quotes(self) -> None:
306         source, expected = read_data("string_quotes")
307         actual = fs(source)
308         self.assertFormatEqual(expected, actual)
309         black.assert_equivalent(source, actual)
310         black.assert_stable(source, actual, line_length=ll)
311         mode = black.FileMode.NO_STRING_NORMALIZATION
312         not_normalized = fs(source, mode=mode)
313         self.assertFormatEqual(source, not_normalized)
314         black.assert_equivalent(source, not_normalized)
315         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
316
317     @patch("black.dump_to_file", dump_to_stderr)
318     def test_slices(self) -> None:
319         source, expected = read_data("slices")
320         actual = fs(source)
321         self.assertFormatEqual(expected, actual)
322         black.assert_equivalent(source, actual)
323         black.assert_stable(source, actual, line_length=ll)
324
325     @patch("black.dump_to_file", dump_to_stderr)
326     def test_comments(self) -> None:
327         source, expected = read_data("comments")
328         actual = fs(source)
329         self.assertFormatEqual(expected, actual)
330         black.assert_equivalent(source, actual)
331         black.assert_stable(source, actual, line_length=ll)
332
333     @patch("black.dump_to_file", dump_to_stderr)
334     def test_comments2(self) -> None:
335         source, expected = read_data("comments2")
336         actual = fs(source)
337         self.assertFormatEqual(expected, actual)
338         black.assert_equivalent(source, actual)
339         black.assert_stable(source, actual, line_length=ll)
340
341     @patch("black.dump_to_file", dump_to_stderr)
342     def test_comments3(self) -> None:
343         source, expected = read_data("comments3")
344         actual = fs(source)
345         self.assertFormatEqual(expected, actual)
346         black.assert_equivalent(source, actual)
347         black.assert_stable(source, actual, line_length=ll)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_comments4(self) -> None:
351         source, expected = read_data("comments4")
352         actual = fs(source)
353         self.assertFormatEqual(expected, actual)
354         black.assert_equivalent(source, actual)
355         black.assert_stable(source, actual, line_length=ll)
356
357     @patch("black.dump_to_file", dump_to_stderr)
358     def test_comments5(self) -> None:
359         source, expected = read_data("comments5")
360         actual = fs(source)
361         self.assertFormatEqual(expected, actual)
362         black.assert_equivalent(source, actual)
363         black.assert_stable(source, actual, line_length=ll)
364
365     @patch("black.dump_to_file", dump_to_stderr)
366     def test_cantfit(self) -> None:
367         source, expected = read_data("cantfit")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         black.assert_equivalent(source, actual)
371         black.assert_stable(source, actual, line_length=ll)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_import_spacing(self) -> None:
375         source, expected = read_data("import_spacing")
376         actual = fs(source)
377         self.assertFormatEqual(expected, actual)
378         black.assert_equivalent(source, actual)
379         black.assert_stable(source, actual, line_length=ll)
380
381     @patch("black.dump_to_file", dump_to_stderr)
382     def test_composition(self) -> None:
383         source, expected = read_data("composition")
384         actual = fs(source)
385         self.assertFormatEqual(expected, actual)
386         black.assert_equivalent(source, actual)
387         black.assert_stable(source, actual, line_length=ll)
388
389     @patch("black.dump_to_file", dump_to_stderr)
390     def test_empty_lines(self) -> None:
391         source, expected = read_data("empty_lines")
392         actual = fs(source)
393         self.assertFormatEqual(expected, actual)
394         black.assert_equivalent(source, actual)
395         black.assert_stable(source, actual, line_length=ll)
396
397     @patch("black.dump_to_file", dump_to_stderr)
398     def test_string_prefixes(self) -> None:
399         source, expected = read_data("string_prefixes")
400         actual = fs(source)
401         self.assertFormatEqual(expected, actual)
402         black.assert_equivalent(source, actual)
403         black.assert_stable(source, actual, line_length=ll)
404
405     @patch("black.dump_to_file", dump_to_stderr)
406     def test_numeric_literals(self) -> None:
407         source, expected = read_data("numeric_literals")
408         actual = fs(source, mode=black.FileMode.PYTHON36)
409         self.assertFormatEqual(expected, actual)
410         black.assert_equivalent(source, actual)
411         black.assert_stable(source, actual, line_length=ll)
412
413     @patch("black.dump_to_file", dump_to_stderr)
414     def test_numeric_literals_py2(self) -> None:
415         source, expected = read_data("numeric_literals_py2")
416         actual = fs(source)
417         self.assertFormatEqual(expected, actual)
418         black.assert_stable(source, actual, line_length=ll)
419
420     @patch("black.dump_to_file", dump_to_stderr)
421     def test_python2(self) -> None:
422         source, expected = read_data("python2")
423         actual = fs(source)
424         self.assertFormatEqual(expected, actual)
425         # black.assert_equivalent(source, actual)
426         black.assert_stable(source, actual, line_length=ll)
427
428     @patch("black.dump_to_file", dump_to_stderr)
429     def test_python2_unicode_literals(self) -> None:
430         source, expected = read_data("python2_unicode_literals")
431         actual = fs(source)
432         self.assertFormatEqual(expected, actual)
433         black.assert_stable(source, actual, line_length=ll)
434
435     @patch("black.dump_to_file", dump_to_stderr)
436     def test_stub(self) -> None:
437         mode = black.FileMode.PYI
438         source, expected = read_data("stub.pyi")
439         actual = fs(source, mode=mode)
440         self.assertFormatEqual(expected, actual)
441         black.assert_stable(source, actual, line_length=ll, mode=mode)
442
443     @patch("black.dump_to_file", dump_to_stderr)
444     def test_python37(self) -> None:
445         source, expected = read_data("python37")
446         actual = fs(source)
447         self.assertFormatEqual(expected, actual)
448         major, minor = sys.version_info[:2]
449         if major > 3 or (major == 3 and minor >= 7):
450             black.assert_equivalent(source, actual)
451         black.assert_stable(source, actual, line_length=ll)
452
453     @patch("black.dump_to_file", dump_to_stderr)
454     def test_fmtonoff(self) -> None:
455         source, expected = read_data("fmtonoff")
456         actual = fs(source)
457         self.assertFormatEqual(expected, actual)
458         black.assert_equivalent(source, actual)
459         black.assert_stable(source, actual, line_length=ll)
460
461     @patch("black.dump_to_file", dump_to_stderr)
462     def test_fmtonoff2(self) -> None:
463         source, expected = read_data("fmtonoff2")
464         actual = fs(source)
465         self.assertFormatEqual(expected, actual)
466         black.assert_equivalent(source, actual)
467         black.assert_stable(source, actual, line_length=ll)
468
469     @patch("black.dump_to_file", dump_to_stderr)
470     def test_remove_empty_parentheses_after_class(self) -> None:
471         source, expected = read_data("class_blank_parentheses")
472         actual = fs(source)
473         self.assertFormatEqual(expected, actual)
474         black.assert_equivalent(source, actual)
475         black.assert_stable(source, actual, line_length=ll)
476
477     @patch("black.dump_to_file", dump_to_stderr)
478     def test_new_line_between_class_and_code(self) -> None:
479         source, expected = read_data("class_methods_new_line")
480         actual = fs(source)
481         self.assertFormatEqual(expected, actual)
482         black.assert_equivalent(source, actual)
483         black.assert_stable(source, actual, line_length=ll)
484
485     @patch("black.dump_to_file", dump_to_stderr)
486     def test_bracket_match(self) -> None:
487         source, expected = read_data("bracketmatch")
488         actual = fs(source)
489         self.assertFormatEqual(expected, actual)
490         black.assert_equivalent(source, actual)
491         black.assert_stable(source, actual, line_length=ll)
492
    def test_report_verbose(self) -> None:
        """Check Report's per-file messages, summary and return code when verbose.

        black.out and black.err are replaced with collectors so the messages each
        report call emits can be counted and inspected.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # Verbose mode also reports unchanged files on the "out" stream.
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files get their own message but count as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to the "err" stream and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call, not per distinct path: f3 is counted again.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are reported but leave the summary counts unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally ("would be ...").
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
587
    def test_report_quiet(self) -> None:
        """Check Report's messages, summary and return code in quiet mode.

        black.out and black.err are replaced with collectors. In quiet mode nothing
        is written to "out"; only failures are written to "err".
        """
        report = black.Report(quiet=True)
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # Cached files count as unchanged.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, having reformatted files yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on "err" and make the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call, not per distinct path: f3 is counted again.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths produce no output and leave the summary unchanged.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary is phrased conditionally ("would be ...").
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
674
675     def test_report_normal(self) -> None:
676         report = black.Report()
677         out_lines = []
678         err_lines = []
679
680         def out(msg: str, **kwargs: Any) -> None:
681             out_lines.append(msg)
682
683         def err(msg: str, **kwargs: Any) -> None:
684             err_lines.append(msg)
685
686         with patch("black.out", out), patch("black.err", err):
687             report.done(Path("f1"), black.Changed.NO)
688             self.assertEqual(len(out_lines), 0)
689             self.assertEqual(len(err_lines), 0)
690             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
691             self.assertEqual(report.return_code, 0)
692             report.done(Path("f2"), black.Changed.YES)
693             self.assertEqual(len(out_lines), 1)
694             self.assertEqual(len(err_lines), 0)
695             self.assertEqual(out_lines[-1], "reformatted f2")
696             self.assertEqual(
697                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
698             )
699             report.done(Path("f3"), black.Changed.CACHED)
700             self.assertEqual(len(out_lines), 1)
701             self.assertEqual(len(err_lines), 0)
702             self.assertEqual(out_lines[-1], "reformatted f2")
703             self.assertEqual(
704                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
705             )
706             self.assertEqual(report.return_code, 0)
707             report.check = True
708             self.assertEqual(report.return_code, 1)
709             report.check = False
710             report.failed(Path("e1"), "boom")
711             self.assertEqual(len(out_lines), 1)
712             self.assertEqual(len(err_lines), 1)
713             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
714             self.assertEqual(
715                 unstyle(str(report)),
716                 "1 file reformatted, 2 files left unchanged, "
717                 "1 file failed to reformat.",
718             )
719             self.assertEqual(report.return_code, 123)
720             report.done(Path("f3"), black.Changed.YES)
721             self.assertEqual(len(out_lines), 2)
722             self.assertEqual(len(err_lines), 1)
723             self.assertEqual(out_lines[-1], "reformatted f3")
724             self.assertEqual(
725                 unstyle(str(report)),
726                 "2 files reformatted, 2 files left unchanged, "
727                 "1 file failed to reformat.",
728             )
729             self.assertEqual(report.return_code, 123)
730             report.failed(Path("e2"), "boom")
731             self.assertEqual(len(out_lines), 2)
732             self.assertEqual(len(err_lines), 2)
733             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
734             self.assertEqual(
735                 unstyle(str(report)),
736                 "2 files reformatted, 2 files left unchanged, "
737                 "2 files failed to reformat.",
738             )
739             self.assertEqual(report.return_code, 123)
740             report.path_ignored(Path("wat"), "no match")
741             self.assertEqual(len(out_lines), 2)
742             self.assertEqual(len(err_lines), 2)
743             self.assertEqual(
744                 unstyle(str(report)),
745                 "2 files reformatted, 2 files left unchanged, "
746                 "2 files failed to reformat.",
747             )
748             self.assertEqual(report.return_code, 123)
749             report.done(Path("f4"), black.Changed.NO)
750             self.assertEqual(len(out_lines), 2)
751             self.assertEqual(len(err_lines), 2)
752             self.assertEqual(
753                 unstyle(str(report)),
754                 "2 files reformatted, 3 files left unchanged, "
755                 "2 files failed to reformat.",
756             )
757             self.assertEqual(report.return_code, 123)
758             report.check = True
759             self.assertEqual(
760                 unstyle(str(report)),
761                 "2 files would be reformatted, 3 files would be left unchanged, "
762                 "2 files would fail to reformat.",
763             )
764
765     def test_is_python36(self) -> None:
766         node = black.lib2to3_parse("def f(*, arg): ...\n")
767         self.assertFalse(black.is_python36(node))
768         node = black.lib2to3_parse("def f(*, arg,): ...\n")
769         self.assertTrue(black.is_python36(node))
770         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
771         self.assertTrue(black.is_python36(node))
772         node = black.lib2to3_parse("123_456\n")
773         self.assertTrue(black.is_python36(node))
774         node = black.lib2to3_parse("123456\n")
775         self.assertFalse(black.is_python36(node))
776         source, expected = read_data("function")
777         node = black.lib2to3_parse(source)
778         self.assertTrue(black.is_python36(node))
779         node = black.lib2to3_parse(expected)
780         self.assertTrue(black.is_python36(node))
781         source, expected = read_data("expression")
782         node = black.lib2to3_parse(source)
783         self.assertFalse(black.is_python36(node))
784         node = black.lib2to3_parse(expected)
785         self.assertFalse(black.is_python36(node))
786
787     def test_get_future_imports(self) -> None:
788         node = black.lib2to3_parse("\n")
789         self.assertEqual(set(), black.get_future_imports(node))
790         node = black.lib2to3_parse("from __future__ import black\n")
791         self.assertEqual({"black"}, black.get_future_imports(node))
792         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
793         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
794         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
795         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
796         node = black.lib2to3_parse(
797             "from __future__ import multiple\nfrom __future__ import imports\n"
798         )
799         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
800         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
801         self.assertEqual({"black"}, black.get_future_imports(node))
802         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
803         self.assertEqual({"black"}, black.get_future_imports(node))
804         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
805         self.assertEqual(set(), black.get_future_imports(node))
806         node = black.lib2to3_parse("from some.module import black\n")
807         self.assertEqual(set(), black.get_future_imports(node))
808         node = black.lib2to3_parse(
809             "from __future__ import unicode_literals as _unicode_literals"
810         )
811         self.assertEqual({"unicode_literals"}, black.get_future_imports(node))
812         node = black.lib2to3_parse(
813             "from __future__ import unicode_literals as _lol, print"
814         )
815         self.assertEqual({"unicode_literals", "print"}, black.get_future_imports(node))
816
817     def test_debug_visitor(self) -> None:
818         source, _ = read_data("debug_visitor.py")
819         expected, _ = read_data("debug_visitor.out")
820         out_lines = []
821         err_lines = []
822
823         def out(msg: str, **kwargs: Any) -> None:
824             out_lines.append(msg)
825
826         def err(msg: str, **kwargs: Any) -> None:
827             err_lines.append(msg)
828
829         with patch("black.out", out), patch("black.err", err):
830             black.DebugVisitor.show(source)
831         actual = "\n".join(out_lines) + "\n"
832         log_name = ""
833         if expected != actual:
834             log_name = black.dump_to_file(*out_lines)
835         self.assertEqual(
836             expected,
837             actual,
838             f"AST print out is different. Actual version dumped to {log_name}",
839         )
840
841     def test_format_file_contents(self) -> None:
842         empty = ""
843         with self.assertRaises(black.NothingChanged):
844             black.format_file_contents(empty, line_length=ll, fast=False)
845         just_nl = "\n"
846         with self.assertRaises(black.NothingChanged):
847             black.format_file_contents(just_nl, line_length=ll, fast=False)
848         same = "l = [1, 2, 3]\n"
849         with self.assertRaises(black.NothingChanged):
850             black.format_file_contents(same, line_length=ll, fast=False)
851         different = "l = [1,2,3]"
852         expected = same
853         actual = black.format_file_contents(different, line_length=ll, fast=False)
854         self.assertEqual(expected, actual)
855         invalid = "return if you can"
856         with self.assertRaises(black.InvalidInput) as e:
857             black.format_file_contents(invalid, line_length=ll, fast=False)
858         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
859
860     def test_endmarker(self) -> None:
861         n = black.lib2to3_parse("\n")
862         self.assertEqual(n.type, black.syms.file_input)
863         self.assertEqual(len(n.children), 1)
864         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
865
866     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
867     def test_assertFormatEqual(self) -> None:
868         out_lines = []
869         err_lines = []
870
871         def out(msg: str, **kwargs: Any) -> None:
872             out_lines.append(msg)
873
874         def err(msg: str, **kwargs: Any) -> None:
875             err_lines.append(msg)
876
877         with patch("black.out", out), patch("black.err", err):
878             with self.assertRaises(AssertionError):
879                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
880
881         out_str = "".join(out_lines)
882         self.assertTrue("Expected tree:" in out_str)
883         self.assertTrue("Actual tree:" in out_str)
884         self.assertEqual("".join(err_lines), "")
885
886     def test_cache_broken_file(self) -> None:
887         mode = black.FileMode.AUTO_DETECT
888         with cache_dir() as workspace:
889             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
890             with cache_file.open("w") as fobj:
891                 fobj.write("this is not a pickle")
892             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
893             src = (workspace / "test.py").resolve()
894             with src.open("w") as fobj:
895                 fobj.write("print('hello')")
896             result = CliRunner().invoke(black.main, [str(src)])
897             self.assertEqual(result.exit_code, 0)
898             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
899             self.assertIn(src, cache)
900
901     def test_cache_single_file_already_cached(self) -> None:
902         mode = black.FileMode.AUTO_DETECT
903         with cache_dir() as workspace:
904             src = (workspace / "test.py").resolve()
905             with src.open("w") as fobj:
906                 fobj.write("print('hello')")
907             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
908             result = CliRunner().invoke(black.main, [str(src)])
909             self.assertEqual(result.exit_code, 0)
910             with src.open("r") as fobj:
911                 self.assertEqual(fobj.read(), "print('hello')")
912
913     @event_loop(close=False)
914     def test_cache_multiple_files(self) -> None:
915         mode = black.FileMode.AUTO_DETECT
916         with cache_dir() as workspace, patch(
917             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
918         ):
919             one = (workspace / "one.py").resolve()
920             with one.open("w") as fobj:
921                 fobj.write("print('hello')")
922             two = (workspace / "two.py").resolve()
923             with two.open("w") as fobj:
924                 fobj.write("print('hello')")
925             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
926             result = CliRunner().invoke(black.main, [str(workspace)])
927             self.assertEqual(result.exit_code, 0)
928             with one.open("r") as fobj:
929                 self.assertEqual(fobj.read(), "print('hello')")
930             with two.open("r") as fobj:
931                 self.assertEqual(fobj.read(), 'print("hello")\n')
932             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
933             self.assertIn(one, cache)
934             self.assertIn(two, cache)
935
936     def test_no_cache_when_writeback_diff(self) -> None:
937         mode = black.FileMode.AUTO_DETECT
938         with cache_dir() as workspace:
939             src = (workspace / "test.py").resolve()
940             with src.open("w") as fobj:
941                 fobj.write("print('hello')")
942             result = CliRunner().invoke(black.main, [str(src), "--diff"])
943             self.assertEqual(result.exit_code, 0)
944             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
945             self.assertFalse(cache_file.exists())
946
947     def test_no_cache_when_stdin(self) -> None:
948         mode = black.FileMode.AUTO_DETECT
949         with cache_dir():
950             result = CliRunner().invoke(
951                 black.main, ["-"], input=BytesIO(b"print('hello')")
952             )
953             self.assertEqual(result.exit_code, 0)
954             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
955             self.assertFalse(cache_file.exists())
956
957     def test_read_cache_no_cachefile(self) -> None:
958         mode = black.FileMode.AUTO_DETECT
959         with cache_dir():
960             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
961
962     def test_write_cache_read_cache(self) -> None:
963         mode = black.FileMode.AUTO_DETECT
964         with cache_dir() as workspace:
965             src = (workspace / "test.py").resolve()
966             src.touch()
967             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
968             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
969             self.assertIn(src, cache)
970             self.assertEqual(cache[src], black.get_cache_info(src))
971
972     def test_filter_cached(self) -> None:
973         with TemporaryDirectory() as workspace:
974             path = Path(workspace)
975             uncached = (path / "uncached").resolve()
976             cached = (path / "cached").resolve()
977             cached_but_changed = (path / "changed").resolve()
978             uncached.touch()
979             cached.touch()
980             cached_but_changed.touch()
981             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
982             todo, done = black.filter_cached(
983                 cache, {uncached, cached, cached_but_changed}
984             )
985             self.assertEqual(todo, {uncached, cached_but_changed})
986             self.assertEqual(done, {cached})
987
988     def test_write_cache_creates_directory_if_needed(self) -> None:
989         mode = black.FileMode.AUTO_DETECT
990         with cache_dir(exists=False) as workspace:
991             self.assertFalse(workspace.exists())
992             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
993             self.assertTrue(workspace.exists())
994
995     @event_loop(close=False)
996     def test_failed_formatting_does_not_get_cached(self) -> None:
997         mode = black.FileMode.AUTO_DETECT
998         with cache_dir() as workspace, patch(
999             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
1000         ):
1001             failing = (workspace / "failing.py").resolve()
1002             with failing.open("w") as fobj:
1003                 fobj.write("not actually python")
1004             clean = (workspace / "clean.py").resolve()
1005             with clean.open("w") as fobj:
1006                 fobj.write('print("hello")\n')
1007             result = CliRunner().invoke(black.main, [str(workspace)])
1008             self.assertEqual(result.exit_code, 123)
1009             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
1010             self.assertNotIn(failing, cache)
1011             self.assertIn(clean, cache)
1012
1013     def test_write_cache_write_fail(self) -> None:
1014         mode = black.FileMode.AUTO_DETECT
1015         with cache_dir(), patch.object(Path, "open") as mock:
1016             mock.side_effect = OSError
1017             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
1018
1019     @event_loop(close=False)
1020     def test_check_diff_use_together(self) -> None:
1021         with cache_dir():
1022             # Files which will be reformatted.
1023             src1 = (THIS_DIR / "data" / "string_quotes.py").resolve()
1024             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
1025             self.assertEqual(result.exit_code, 1, result.output)
1026             # Files which will not be reformatted.
1027             src2 = (THIS_DIR / "data" / "composition.py").resolve()
1028             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
1029             self.assertEqual(result.exit_code, 0, result.output)
1030             # Multi file command.
1031             result = CliRunner().invoke(
1032                 black.main, [str(src1), str(src2), "--diff", "--check"]
1033             )
1034             self.assertEqual(result.exit_code, 1, result.output)
1035
1036     def test_no_files(self) -> None:
1037         with cache_dir():
1038             # Without an argument, black exits with error code 0.
1039             result = CliRunner().invoke(black.main, [])
1040             self.assertEqual(result.exit_code, 0)
1041
1042     def test_broken_symlink(self) -> None:
1043         with cache_dir() as workspace:
1044             symlink = workspace / "broken_link.py"
1045             try:
1046                 symlink.symlink_to("nonexistent.py")
1047             except OSError as e:
1048                 self.skipTest(f"Can't create symlinks: {e}")
1049             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
1050             self.assertEqual(result.exit_code, 0)
1051
1052     def test_read_cache_line_lengths(self) -> None:
1053         mode = black.FileMode.AUTO_DETECT
1054         with cache_dir() as workspace:
1055             path = (workspace / "file.py").resolve()
1056             path.touch()
1057             black.write_cache({}, [path], 1, mode)
1058             one = black.read_cache(1, mode)
1059             self.assertIn(path, one)
1060             two = black.read_cache(2, mode)
1061             self.assertNotIn(path, two)
1062
1063     def test_single_file_force_pyi(self) -> None:
1064         reg_mode = black.FileMode.AUTO_DETECT
1065         pyi_mode = black.FileMode.PYI
1066         contents, expected = read_data("force_pyi")
1067         with cache_dir() as workspace:
1068             path = (workspace / "file.py").resolve()
1069             with open(path, "w") as fh:
1070                 fh.write(contents)
1071             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
1072             self.assertEqual(result.exit_code, 0)
1073             with open(path, "r") as fh:
1074                 actual = fh.read()
1075             # verify cache with --pyi is separate
1076             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1077             self.assertIn(path, pyi_cache)
1078             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1079             self.assertNotIn(path, normal_cache)
1080         self.assertEqual(actual, expected)
1081
1082     @event_loop(close=False)
1083     def test_multi_file_force_pyi(self) -> None:
1084         reg_mode = black.FileMode.AUTO_DETECT
1085         pyi_mode = black.FileMode.PYI
1086         contents, expected = read_data("force_pyi")
1087         with cache_dir() as workspace:
1088             paths = [
1089                 (workspace / "file1.py").resolve(),
1090                 (workspace / "file2.py").resolve(),
1091             ]
1092             for path in paths:
1093                 with open(path, "w") as fh:
1094                     fh.write(contents)
1095             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
1096             self.assertEqual(result.exit_code, 0)
1097             for path in paths:
1098                 with open(path, "r") as fh:
1099                     actual = fh.read()
1100                 self.assertEqual(actual, expected)
1101             # verify cache with --pyi is separate
1102             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
1103             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1104             for path in paths:
1105                 self.assertIn(path, pyi_cache)
1106                 self.assertNotIn(path, normal_cache)
1107
1108     def test_pipe_force_pyi(self) -> None:
1109         source, expected = read_data("force_pyi")
1110         result = CliRunner().invoke(
1111             black.main, ["-", "-q", "--pyi"], input=BytesIO(source.encode("utf8"))
1112         )
1113         self.assertEqual(result.exit_code, 0)
1114         actual = result.output
1115         self.assertFormatEqual(actual, expected)
1116
1117     def test_single_file_force_py36(self) -> None:
1118         reg_mode = black.FileMode.AUTO_DETECT
1119         py36_mode = black.FileMode.PYTHON36
1120         source, expected = read_data("force_py36")
1121         with cache_dir() as workspace:
1122             path = (workspace / "file.py").resolve()
1123             with open(path, "w") as fh:
1124                 fh.write(source)
1125             result = CliRunner().invoke(black.main, [str(path), "--py36"])
1126             self.assertEqual(result.exit_code, 0)
1127             with open(path, "r") as fh:
1128                 actual = fh.read()
1129             # verify cache with --py36 is separate
1130             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1131             self.assertIn(path, py36_cache)
1132             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1133             self.assertNotIn(path, normal_cache)
1134         self.assertEqual(actual, expected)
1135
1136     @event_loop(close=False)
1137     def test_multi_file_force_py36(self) -> None:
1138         reg_mode = black.FileMode.AUTO_DETECT
1139         py36_mode = black.FileMode.PYTHON36
1140         source, expected = read_data("force_py36")
1141         with cache_dir() as workspace:
1142             paths = [
1143                 (workspace / "file1.py").resolve(),
1144                 (workspace / "file2.py").resolve(),
1145             ]
1146             for path in paths:
1147                 with open(path, "w") as fh:
1148                     fh.write(source)
1149             result = CliRunner().invoke(
1150                 black.main, [str(p) for p in paths] + ["--py36"]
1151             )
1152             self.assertEqual(result.exit_code, 0)
1153             for path in paths:
1154                 with open(path, "r") as fh:
1155                     actual = fh.read()
1156                 self.assertEqual(actual, expected)
1157             # verify cache with --py36 is separate
1158             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1159             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1160             for path in paths:
1161                 self.assertIn(path, pyi_cache)
1162                 self.assertNotIn(path, normal_cache)
1163
1164     def test_pipe_force_py36(self) -> None:
1165         source, expected = read_data("force_py36")
1166         result = CliRunner().invoke(
1167             black.main, ["-", "-q", "--py36"], input=BytesIO(source.encode("utf8"))
1168         )
1169         self.assertEqual(result.exit_code, 0)
1170         actual = result.output
1171         self.assertFormatEqual(actual, expected)
1172
1173     def test_include_exclude(self) -> None:
1174         path = THIS_DIR / "data" / "include_exclude_tests"
1175         include = re.compile(r"\.pyi?$")
1176         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1177         report = black.Report()
1178         sources: List[Path] = []
1179         expected = [
1180             Path(path / "b/dont_exclude/a.py"),
1181             Path(path / "b/dont_exclude/a.pyi"),
1182         ]
1183         this_abs = THIS_DIR.resolve()
1184         sources.extend(
1185             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1186         )
1187         self.assertEqual(sorted(expected), sorted(sources))
1188
1189     def test_empty_include(self) -> None:
1190         path = THIS_DIR / "data" / "include_exclude_tests"
1191         report = black.Report()
1192         empty = re.compile(r"")
1193         sources: List[Path] = []
1194         expected = [
1195             Path(path / "b/exclude/a.pie"),
1196             Path(path / "b/exclude/a.py"),
1197             Path(path / "b/exclude/a.pyi"),
1198             Path(path / "b/dont_exclude/a.pie"),
1199             Path(path / "b/dont_exclude/a.py"),
1200             Path(path / "b/dont_exclude/a.pyi"),
1201             Path(path / "b/.definitely_exclude/a.pie"),
1202             Path(path / "b/.definitely_exclude/a.py"),
1203             Path(path / "b/.definitely_exclude/a.pyi"),
1204         ]
1205         this_abs = THIS_DIR.resolve()
1206         sources.extend(
1207             black.gen_python_files_in_dir(
1208                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1209             )
1210         )
1211         self.assertEqual(sorted(expected), sorted(sources))
1212
1213     def test_empty_exclude(self) -> None:
1214         path = THIS_DIR / "data" / "include_exclude_tests"
1215         report = black.Report()
1216         empty = re.compile(r"")
1217         sources: List[Path] = []
1218         expected = [
1219             Path(path / "b/dont_exclude/a.py"),
1220             Path(path / "b/dont_exclude/a.pyi"),
1221             Path(path / "b/exclude/a.py"),
1222             Path(path / "b/exclude/a.pyi"),
1223             Path(path / "b/.definitely_exclude/a.py"),
1224             Path(path / "b/.definitely_exclude/a.pyi"),
1225         ]
1226         this_abs = THIS_DIR.resolve()
1227         sources.extend(
1228             black.gen_python_files_in_dir(
1229                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1230             )
1231         )
1232         self.assertEqual(sorted(expected), sorted(sources))
1233
1234     def test_invalid_include_exclude(self) -> None:
1235         for option in ["--include", "--exclude"]:
1236             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
1237             self.assertEqual(result.exit_code, 2)
1238
1239     def test_preserves_line_endings(self) -> None:
1240         with TemporaryDirectory() as workspace:
1241             test_file = Path(workspace) / "test.py"
1242             for nl in ["\n", "\r\n"]:
1243                 contents = nl.join(["def f(  ):", "    pass"])
1244                 test_file.write_bytes(contents.encode())
1245                 ff(test_file, write_back=black.WriteBack.YES)
1246                 updated_contents: bytes = test_file.read_bytes()
1247                 self.assertIn(nl.encode(), updated_contents)
1248                 if nl == "\n":
1249                     self.assertNotIn(b"\r\n", updated_contents)
1250
1251     def test_preserves_line_endings_via_stdin(self) -> None:
1252         for nl in ["\n", "\r\n"]:
1253             contents = nl.join(["def f(  ):", "    pass"])
1254             runner = BlackRunner()
1255             result = runner.invoke(
1256                 black.main, ["-", "--fast"], input=BytesIO(contents.encode("utf8"))
1257             )
1258             self.assertEqual(result.exit_code, 0)
1259             output = runner.stdout_bytes
1260             self.assertIn(nl.encode("utf8"), output)
1261             if nl == "\n":
1262                 self.assertNotIn(b"\r\n", output)
1263
1264     def test_assert_equivalent_different_asts(self) -> None:
1265         with self.assertRaises(AssertionError):
1266             black.assert_equivalent("{}", "None")
1267
1268     def test_symlink_out_of_root_directory(self) -> None:
1269         path = MagicMock()
1270         root = THIS_DIR
1271         child = MagicMock()
1272         include = re.compile(black.DEFAULT_INCLUDES)
1273         exclude = re.compile(black.DEFAULT_EXCLUDES)
1274         report = black.Report()
1275         # `child` should behave like a symlink which resolved path is clearly
1276         # outside of the `root` directory.
1277         path.iterdir.return_value = [child]
1278         child.resolve.return_value = Path("/a/b/c")
1279         child.is_symlink.return_value = True
1280         try:
1281             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1282         except ValueError as ve:
1283             self.fail("`get_python_files_in_dir()` failed: {ve}")
1284         path.iterdir.assert_called_once()
1285         child.resolve.assert_called_once()
1286         child.is_symlink.assert_called_once()
1287         # `child` should behave like a strange file which resolved path is clearly
1288         # outside of the `root` directory.
1289         child.is_symlink.return_value = False
1290         with self.assertRaises(ValueError):
1291             list(black.gen_python_files_in_dir(path, root, include, exclude, report))
1292         path.iterdir.assert_called()
1293         self.assertEqual(path.iterdir.call_count, 2)
1294         child.resolve.assert_called()
1295         self.assertEqual(child.resolve.call_count, 2)
1296         child.is_symlink.assert_called()
1297         self.assertEqual(child.is_symlink.call_count, 2)
1298
1299     def test_shhh_click(self) -> None:
1300         try:
1301             from click import _unicodefun  # type: ignore
1302         except ModuleNotFoundError:
1303             self.skipTest("Incompatible Click version")
1304         if not hasattr(_unicodefun, "_verify_python3_env"):
1305             self.skipTest("Incompatible Click version")
1306         # First, let's see if Click is crashing with a preferred ASCII charset.
1307         with patch("locale.getpreferredencoding") as gpe:
1308             gpe.return_value = "ASCII"
1309             with self.assertRaises(RuntimeError):
1310                 _unicodefun._verify_python3_env()
1311         # Now, let's silence Click...
1312         black.patch_click()
1313         # ...and confirm it's silent.
1314         with patch("locale.getpreferredencoding") as gpe:
1315             gpe.return_value = "ASCII"
1316             try:
1317                 _unicodefun._verify_python3_env()
1318             except RuntimeError as re:
1319                 self.fail(f"`patch_click()` failed, exception still raised: {re}")
1320
1321     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1322     @async_test
1323     async def test_blackd_request_needs_formatting(self) -> None:
1324         app = blackd.make_app()
1325         async with TestClient(TestServer(app)) as client:
1326             response = await client.post("/", data=b"print('hello world')")
1327             self.assertEqual(response.status, 200)
1328             self.assertEqual(response.charset, "utf8")
1329             self.assertEqual(await response.read(), b'print("hello world")\n')
1330
1331     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1332     @async_test
1333     async def test_blackd_request_no_change(self) -> None:
1334         app = blackd.make_app()
1335         async with TestClient(TestServer(app)) as client:
1336             response = await client.post("/", data=b'print("hello world")\n')
1337             self.assertEqual(response.status, 204)
1338             self.assertEqual(await response.read(), b"")
1339
1340     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1341     @async_test
1342     async def test_blackd_request_syntax_error(self) -> None:
1343         app = blackd.make_app()
1344         async with TestClient(TestServer(app)) as client:
1345             response = await client.post("/", data=b"what even ( is")
1346             self.assertEqual(response.status, 400)
1347             content = await response.text()
1348             self.assertTrue(
1349                 content.startswith("Cannot parse"),
1350                 msg=f"Expected error to start with 'Cannot parse', got {repr(content)}",
1351             )
1352
1353     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1354     @async_test
1355     async def test_blackd_unsupported_version(self) -> None:
1356         app = blackd.make_app()
1357         async with TestClient(TestServer(app)) as client:
1358             response = await client.post(
1359                 "/", data=b"what", headers={blackd.VERSION_HEADER: "2"}
1360             )
1361             self.assertEqual(response.status, 501)
1362
1363     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1364     @async_test
1365     async def test_blackd_supported_version(self) -> None:
1366         app = blackd.make_app()
1367         async with TestClient(TestServer(app)) as client:
1368             response = await client.post(
1369                 "/", data=b"what", headers={blackd.VERSION_HEADER: "1"}
1370             )
1371             self.assertEqual(response.status, 200)
1372
1373     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1374     @async_test
1375     async def test_blackd_invalid_python_variant(self) -> None:
1376         app = blackd.make_app()
1377         async with TestClient(TestServer(app)) as client:
1378             response = await client.post(
1379                 "/", data=b"what", headers={blackd.PYTHON_VARIANT_HEADER: "lol"}
1380             )
1381             self.assertEqual(response.status, 400)
1382
1383     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1384     @async_test
1385     async def test_blackd_pyi(self) -> None:
1386         app = blackd.make_app()
1387         async with TestClient(TestServer(app)) as client:
1388             source, expected = read_data("stub.pyi")
1389             response = await client.post(
1390                 "/", data=source, headers={blackd.PYTHON_VARIANT_HEADER: "pyi"}
1391             )
1392             self.assertEqual(response.status, 200)
1393             self.assertEqual(await response.text(), expected)
1394
1395     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1396     @async_test
1397     async def test_blackd_py36(self) -> None:
1398         app = blackd.make_app()
1399         async with TestClient(TestServer(app)) as client:
1400             response = await client.post(
1401                 "/",
1402                 data=(
1403                     "def f(\n"
1404                     "    and_has_a_bunch_of,\n"
1405                     "    very_long_arguments_too,\n"
1406                     "    and_lots_of_them_as_well_lol,\n"
1407                     "    **and_very_long_keyword_arguments\n"
1408                     "):\n"
1409                     "    pass\n"
1410                 ),
1411                 headers={blackd.PYTHON_VARIANT_HEADER: "3.6"},
1412             )
1413             self.assertEqual(response.status, 200)
1414             response = await client.post(
1415                 "/",
1416                 data=(
1417                     "def f(\n"
1418                     "    and_has_a_bunch_of,\n"
1419                     "    very_long_arguments_too,\n"
1420                     "    and_lots_of_them_as_well_lol,\n"
1421                     "    **and_very_long_keyword_arguments\n"
1422                     "):\n"
1423                     "    pass\n"
1424                 ),
1425                 headers={blackd.PYTHON_VARIANT_HEADER: "3.5"},
1426             )
1427             self.assertEqual(response.status, 204)
1428             response = await client.post(
1429                 "/",
1430                 data=(
1431                     "def f(\n"
1432                     "    and_has_a_bunch_of,\n"
1433                     "    very_long_arguments_too,\n"
1434                     "    and_lots_of_them_as_well_lol,\n"
1435                     "    **and_very_long_keyword_arguments\n"
1436                     "):\n"
1437                     "    pass\n"
1438                 ),
1439                 headers={blackd.PYTHON_VARIANT_HEADER: "2"},
1440             )
1441             self.assertEqual(response.status, 204)
1442
1443     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1444     @async_test
1445     async def test_blackd_fast(self) -> None:
1446         app = blackd.make_app()
1447         async with TestClient(TestServer(app)) as client:
1448             response = await client.post("/", data=b"ur'hello'")
1449             self.assertEqual(response.status, 500)
1450             self.assertIn("failed to parse source file", await response.text())
1451             response = await client.post(
1452                 "/", data=b"ur'hello'", headers={blackd.FAST_OR_SAFE_HEADER: "fast"}
1453             )
1454             self.assertEqual(response.status, 200)
1455
1456     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1457     @async_test
1458     async def test_blackd_line_length(self) -> None:
1459         app = blackd.make_app()
1460         async with TestClient(TestServer(app)) as client:
1461             response = await client.post(
1462                 "/", data=b'print("hello")\n', headers={blackd.LINE_LENGTH_HEADER: "7"}
1463             )
1464             self.assertEqual(response.status, 200)
1465
1466     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1467     @async_test
1468     async def test_blackd_invalid_line_length(self) -> None:
1469         app = blackd.make_app()
1470         async with TestClient(TestServer(app)) as client:
1471             response = await client.post(
1472                 "/",
1473                 data=b'print("hello")\n',
1474                 headers={blackd.LINE_LENGTH_HEADER: "NaN"},
1475             )
1476             self.assertEqual(response.status, 400)
1477
1478     @unittest.skipUnless(has_blackd_deps, "blackd's dependencies are not installed")
1479     def test_blackd_main(self) -> None:
1480         with patch("blackd.web.run_app"):
1481             result = CliRunner().invoke(blackd.main, [])
1482             if result.exception is not None:
1483                 raise result.exception
1484             self.assertEqual(result.exit_code, 0)
1485
1486
if __name__ == "__main__":
    # Only when run as a script: hand over to unittest's command-line runner,
    # loading the tests from the ``test_black`` module by name.
    unittest.main(module="test_black")