git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Preserve line endings when formatting a file in place (#288)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import BytesIO, TextIOWrapper
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14 import re
15
16 from click import unstyle
17 from click.testing import CliRunner
18
19 import black
20
# Line length shared by every formatting call and stability check in this module.
ll = 88
# ``fs`` formats a source string; ``ff`` formats a file on disk in fast mode.
fs = partial(black.format_str, line_length=ll)
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Assembled from two literals, presumably so this file never contains the whole
# marker: ``read_data`` deletes it from fixtures and ``test_self`` reads this file.
EMPTY_LINE = "# EMPTY LINE WITH" + " WHITESPACE (this comment will be removed)"
27
28
def dump_to_stderr(*output: str) -> str:
    """Return ``output`` joined by newlines, with a newline before and after.

    Tests patch this in place of ``black.dump_to_file`` so that code which would
    otherwise receive the name of a dump file receives the dumped text itself.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
31
32
def read_data(name: str) -> Tuple[str, str]:
    """Return ``(input, output)`` source text for the test fixture ``name``.

    ``name`` is resolved relative to ``THIS_DIR``; ``.py`` is appended unless it
    already ends in ``.py``, ``.pyi``, ``.out`` or ``.diff``. Every occurrence of
    ``EMPTY_LINE`` is deleted from each line. Lines before the first line whose
    right-stripped text is the output marker comment form the input; the lines
    after it form the expected output (further marker lines are dropped). A file
    without any marker is treated as already formatted, so output equals input.
    Both returned strings are stripped and end with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    _input: List[str] = []
    _output: List[str] = []
    seen_marker = False
    with open(THIS_DIR / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    for line in lines:
        line = line.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            seen_marker = True
            continue

        (_output if seen_marker else _input).append(line)
    if not seen_marker:
        # Decide on the marker itself, not on an empty output section: a fixture
        # whose expected output is empty must not be mistaken for an unmarked one.
        _output = _input[:]
    return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
53
54
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch ``black.CACHE_DIR`` to a path inside a fresh temporary directory.

    With ``exists=True`` the patched path is the temporary directory itself;
    otherwise it is its not-yet-created ``new`` subdirectory. The patched path is
    yielded, and the temporary directory is deleted when the context exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
63
64
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Install a brand-new event loop as the current loop for the ``with`` body.

    The loop that was current on entry is reinstated on exit. The new loop is
    closed on exit only when ``close`` is true; with ``close=False`` it is left
    open (presumably for code under test that closes it itself -- confirm).
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
78
79
80 class BlackTestCase(unittest.TestCase):
81     maxDiff = None
82
83     def assertFormatEqual(self, expected: str, actual: str) -> None:
84         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
85             bdv: black.DebugVisitor[Any]
86             black.out("Expected tree:", fg="green")
87             try:
88                 exp_node = black.lib2to3_parse(expected)
89                 bdv = black.DebugVisitor()
90                 list(bdv.visit(exp_node))
91             except Exception as ve:
92                 black.err(str(ve))
93             black.out("Actual tree:", fg="red")
94             try:
95                 exp_node = black.lib2to3_parse(actual)
96                 bdv = black.DebugVisitor()
97                 list(bdv.visit(exp_node))
98             except Exception as ve:
99                 black.err(str(ve))
100         self.assertEqual(expected, actual)
101
102     @patch("black.dump_to_file", dump_to_stderr)
103     def test_self(self) -> None:
104         source, expected = read_data("test_black")
105         actual = fs(source)
106         self.assertFormatEqual(expected, actual)
107         black.assert_equivalent(source, actual)
108         black.assert_stable(source, actual, line_length=ll)
109         self.assertFalse(ff(THIS_FILE))
110
111     @patch("black.dump_to_file", dump_to_stderr)
112     def test_black(self) -> None:
113         source, expected = read_data("../black")
114         actual = fs(source)
115         self.assertFormatEqual(expected, actual)
116         black.assert_equivalent(source, actual)
117         black.assert_stable(source, actual, line_length=ll)
118         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
119
120     def test_piping(self) -> None:
121         source, expected = read_data("../black")
122         hold_stdin, hold_stdout = sys.stdin, sys.stdout
123         try:
124             sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
125             sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
126             sys.stdin.buffer.name = "<stdin>"  # type: ignore
127             black.format_stdin_to_stdout(
128                 line_length=ll, fast=True, write_back=black.WriteBack.YES
129             )
130             sys.stdout.seek(0)
131             actual = sys.stdout.read()
132         finally:
133             sys.stdin, sys.stdout = hold_stdin, hold_stdout
134         self.assertFormatEqual(expected, actual)
135         black.assert_equivalent(source, actual)
136         black.assert_stable(source, actual, line_length=ll)
137
138     def test_piping_diff(self) -> None:
139         source, _ = read_data("expression.py")
140         expected, _ = read_data("expression.diff")
141         hold_stdin, hold_stdout = sys.stdin, sys.stdout
142         try:
143             sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
144             sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
145             sys.stdin.buffer.name = "<stdin>"  # type: ignore
146             black.format_stdin_to_stdout(
147                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
148             )
149             sys.stdout.seek(0)
150             actual = sys.stdout.read()
151         finally:
152             sys.stdin, sys.stdout = hold_stdin, hold_stdout
153         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
154         self.assertEqual(expected, actual)
155
156     @patch("black.dump_to_file", dump_to_stderr)
157     def test_setup(self) -> None:
158         source, expected = read_data("../setup")
159         actual = fs(source)
160         self.assertFormatEqual(expected, actual)
161         black.assert_equivalent(source, actual)
162         black.assert_stable(source, actual, line_length=ll)
163         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
164
165     @patch("black.dump_to_file", dump_to_stderr)
166     def test_function(self) -> None:
167         source, expected = read_data("function")
168         actual = fs(source)
169         self.assertFormatEqual(expected, actual)
170         black.assert_equivalent(source, actual)
171         black.assert_stable(source, actual, line_length=ll)
172
173     @patch("black.dump_to_file", dump_to_stderr)
174     def test_function2(self) -> None:
175         source, expected = read_data("function2")
176         actual = fs(source)
177         self.assertFormatEqual(expected, actual)
178         black.assert_equivalent(source, actual)
179         black.assert_stable(source, actual, line_length=ll)
180
181     @patch("black.dump_to_file", dump_to_stderr)
182     def test_expression(self) -> None:
183         source, expected = read_data("expression")
184         actual = fs(source)
185         self.assertFormatEqual(expected, actual)
186         black.assert_equivalent(source, actual)
187         black.assert_stable(source, actual, line_length=ll)
188
189     def test_expression_ff(self) -> None:
190         source, expected = read_data("expression")
191         tmp_file = Path(black.dump_to_file(source))
192         try:
193             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
194             with open(tmp_file, encoding="utf8") as f:
195                 actual = f.read()
196         finally:
197             os.unlink(tmp_file)
198         self.assertFormatEqual(expected, actual)
199         with patch("black.dump_to_file", dump_to_stderr):
200             black.assert_equivalent(source, actual)
201             black.assert_stable(source, actual, line_length=ll)
202
203     def test_expression_diff(self) -> None:
204         source, _ = read_data("expression.py")
205         expected, _ = read_data("expression.diff")
206         tmp_file = Path(black.dump_to_file(source))
207         hold_stdout = sys.stdout
208         try:
209             sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
210             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
211             sys.stdout.seek(0)
212             actual = sys.stdout.read()
213             actual = actual.replace(str(tmp_file), "<stdin>")
214         finally:
215             sys.stdout = hold_stdout
216             os.unlink(tmp_file)
217         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
218         if expected != actual:
219             dump = black.dump_to_file(actual)
220             msg = (
221                 f"Expected diff isn't equal to the actual. If you made changes "
222                 f"to expression.py and this is an anticipated difference, "
223                 f"overwrite tests/expression.diff with {dump}"
224             )
225             self.assertEqual(expected, actual, msg)
226
227     @patch("black.dump_to_file", dump_to_stderr)
228     def test_fstring(self) -> None:
229         source, expected = read_data("fstring")
230         actual = fs(source)
231         self.assertFormatEqual(expected, actual)
232         black.assert_equivalent(source, actual)
233         black.assert_stable(source, actual, line_length=ll)
234
235     @patch("black.dump_to_file", dump_to_stderr)
236     def test_string_quotes(self) -> None:
237         source, expected = read_data("string_quotes")
238         actual = fs(source)
239         self.assertFormatEqual(expected, actual)
240         black.assert_equivalent(source, actual)
241         black.assert_stable(source, actual, line_length=ll)
242         mode = black.FileMode.NO_STRING_NORMALIZATION
243         not_normalized = fs(source, mode=mode)
244         self.assertFormatEqual(source, not_normalized)
245         black.assert_equivalent(source, not_normalized)
246         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
247
248     @patch("black.dump_to_file", dump_to_stderr)
249     def test_slices(self) -> None:
250         source, expected = read_data("slices")
251         actual = fs(source)
252         self.assertFormatEqual(expected, actual)
253         black.assert_equivalent(source, actual)
254         black.assert_stable(source, actual, line_length=ll)
255
256     @patch("black.dump_to_file", dump_to_stderr)
257     def test_comments(self) -> None:
258         source, expected = read_data("comments")
259         actual = fs(source)
260         self.assertFormatEqual(expected, actual)
261         black.assert_equivalent(source, actual)
262         black.assert_stable(source, actual, line_length=ll)
263
264     @patch("black.dump_to_file", dump_to_stderr)
265     def test_comments2(self) -> None:
266         source, expected = read_data("comments2")
267         actual = fs(source)
268         self.assertFormatEqual(expected, actual)
269         black.assert_equivalent(source, actual)
270         black.assert_stable(source, actual, line_length=ll)
271
272     @patch("black.dump_to_file", dump_to_stderr)
273     def test_comments3(self) -> None:
274         source, expected = read_data("comments3")
275         actual = fs(source)
276         self.assertFormatEqual(expected, actual)
277         black.assert_equivalent(source, actual)
278         black.assert_stable(source, actual, line_length=ll)
279
280     @patch("black.dump_to_file", dump_to_stderr)
281     def test_comments4(self) -> None:
282         source, expected = read_data("comments4")
283         actual = fs(source)
284         self.assertFormatEqual(expected, actual)
285         black.assert_equivalent(source, actual)
286         black.assert_stable(source, actual, line_length=ll)
287
288     @patch("black.dump_to_file", dump_to_stderr)
289     def test_comments5(self) -> None:
290         source, expected = read_data("comments5")
291         actual = fs(source)
292         self.assertFormatEqual(expected, actual)
293         black.assert_equivalent(source, actual)
294         black.assert_stable(source, actual, line_length=ll)
295
296     @patch("black.dump_to_file", dump_to_stderr)
297     def test_cantfit(self) -> None:
298         source, expected = read_data("cantfit")
299         actual = fs(source)
300         self.assertFormatEqual(expected, actual)
301         black.assert_equivalent(source, actual)
302         black.assert_stable(source, actual, line_length=ll)
303
304     @patch("black.dump_to_file", dump_to_stderr)
305     def test_import_spacing(self) -> None:
306         source, expected = read_data("import_spacing")
307         actual = fs(source)
308         self.assertFormatEqual(expected, actual)
309         black.assert_equivalent(source, actual)
310         black.assert_stable(source, actual, line_length=ll)
311
312     @patch("black.dump_to_file", dump_to_stderr)
313     def test_composition(self) -> None:
314         source, expected = read_data("composition")
315         actual = fs(source)
316         self.assertFormatEqual(expected, actual)
317         black.assert_equivalent(source, actual)
318         black.assert_stable(source, actual, line_length=ll)
319
320     @patch("black.dump_to_file", dump_to_stderr)
321     def test_empty_lines(self) -> None:
322         source, expected = read_data("empty_lines")
323         actual = fs(source)
324         self.assertFormatEqual(expected, actual)
325         black.assert_equivalent(source, actual)
326         black.assert_stable(source, actual, line_length=ll)
327
328     @patch("black.dump_to_file", dump_to_stderr)
329     def test_string_prefixes(self) -> None:
330         source, expected = read_data("string_prefixes")
331         actual = fs(source)
332         self.assertFormatEqual(expected, actual)
333         black.assert_equivalent(source, actual)
334         black.assert_stable(source, actual, line_length=ll)
335
336     @patch("black.dump_to_file", dump_to_stderr)
337     def test_python2(self) -> None:
338         source, expected = read_data("python2")
339         actual = fs(source)
340         self.assertFormatEqual(expected, actual)
341         # black.assert_equivalent(source, actual)
342         black.assert_stable(source, actual, line_length=ll)
343
344     @patch("black.dump_to_file", dump_to_stderr)
345     def test_python2_unicode_literals(self) -> None:
346         source, expected = read_data("python2_unicode_literals")
347         actual = fs(source)
348         self.assertFormatEqual(expected, actual)
349         black.assert_stable(source, actual, line_length=ll)
350
351     @patch("black.dump_to_file", dump_to_stderr)
352     def test_stub(self) -> None:
353         mode = black.FileMode.PYI
354         source, expected = read_data("stub.pyi")
355         actual = fs(source, mode=mode)
356         self.assertFormatEqual(expected, actual)
357         black.assert_stable(source, actual, line_length=ll, mode=mode)
358
359     @patch("black.dump_to_file", dump_to_stderr)
360     def test_fmtonoff(self) -> None:
361         source, expected = read_data("fmtonoff")
362         actual = fs(source)
363         self.assertFormatEqual(expected, actual)
364         black.assert_equivalent(source, actual)
365         black.assert_stable(source, actual, line_length=ll)
366
367     @patch("black.dump_to_file", dump_to_stderr)
368     def test_remove_empty_parentheses_after_class(self) -> None:
369         source, expected = read_data("class_blank_parentheses")
370         actual = fs(source)
371         self.assertFormatEqual(expected, actual)
372         black.assert_equivalent(source, actual)
373         black.assert_stable(source, actual, line_length=ll)
374
375     @patch("black.dump_to_file", dump_to_stderr)
376     def test_new_line_between_class_and_code(self) -> None:
377         source, expected = read_data("class_methods_new_line")
378         actual = fs(source)
379         self.assertFormatEqual(expected, actual)
380         black.assert_equivalent(source, actual)
381         black.assert_stable(source, actual, line_length=ll)
382
    def test_report_verbose(self) -> None:
        """Check ``black.Report(verbose=True)`` messages, summary and return code.

        Each step feeds one more result into the same report, so every assertion
        depends on the cumulative state built up by the steps before it.
        """
        report = black.Report(verbose=True)
        # Capture what the report sends to black.out / black.err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cache hit gets its own message but counts as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file alone yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Any failure goes to black.err and makes the return code 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Counts are per call: reporting f3 again adds a "reformatted" entry
            # without removing its earlier "left unchanged" one.
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are printed in verbose mode but not counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode switches the summary to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
477
    def test_report_quiet(self) -> None:
        """Check ``black.Report(quiet=True)``: only errors are ever printed.

        Each step feeds one more result into the same report, so every assertion
        depends on the cumulative state built up by the steps before it.
        """
        report = black.Report(quiet=True)
        # Capture what the report sends to black.out / black.err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cache hit counts as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file alone yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures are still reported on black.err in quiet mode.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither printed nor counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode switches the summary to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
564
    def test_report_normal(self) -> None:
        """Check a default ``black.Report()``: only reformatted files are printed.

        Each step feeds one more result into the same report, so every assertion
        depends on the cumulative state built up by the steps before it.
        """
        report = black.Report()
        # Capture what the report sends to black.out / black.err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 0)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # A cache hit prints nothing (the last message is still f2's) and
            # counts as "left unchanged".
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file alone yields return code 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are neither printed nor counted at default verbosity.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Check mode switches the summary to conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
654
655     def test_is_python36(self) -> None:
656         node = black.lib2to3_parse("def f(*, arg): ...\n")
657         self.assertFalse(black.is_python36(node))
658         node = black.lib2to3_parse("def f(*, arg,): ...\n")
659         self.assertTrue(black.is_python36(node))
660         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
661         self.assertTrue(black.is_python36(node))
662         source, expected = read_data("function")
663         node = black.lib2to3_parse(source)
664         self.assertTrue(black.is_python36(node))
665         node = black.lib2to3_parse(expected)
666         self.assertTrue(black.is_python36(node))
667         source, expected = read_data("expression")
668         node = black.lib2to3_parse(source)
669         self.assertFalse(black.is_python36(node))
670         node = black.lib2to3_parse(expected)
671         self.assertFalse(black.is_python36(node))
672
673     def test_get_future_imports(self) -> None:
674         node = black.lib2to3_parse("\n")
675         self.assertEqual(set(), black.get_future_imports(node))
676         node = black.lib2to3_parse("from __future__ import black\n")
677         self.assertEqual({"black"}, black.get_future_imports(node))
678         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
679         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
680         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
681         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
682         node = black.lib2to3_parse(
683             "from __future__ import multiple\nfrom __future__ import imports\n"
684         )
685         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
686         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
687         self.assertEqual({"black"}, black.get_future_imports(node))
688         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
689         self.assertEqual({"black"}, black.get_future_imports(node))
690         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
691         self.assertEqual(set(), black.get_future_imports(node))
692         node = black.lib2to3_parse("from some.module import black\n")
693         self.assertEqual(set(), black.get_future_imports(node))
694
695     def test_debug_visitor(self) -> None:
696         source, _ = read_data("debug_visitor.py")
697         expected, _ = read_data("debug_visitor.out")
698         out_lines = []
699         err_lines = []
700
701         def out(msg: str, **kwargs: Any) -> None:
702             out_lines.append(msg)
703
704         def err(msg: str, **kwargs: Any) -> None:
705             err_lines.append(msg)
706
707         with patch("black.out", out), patch("black.err", err):
708             black.DebugVisitor.show(source)
709         actual = "\n".join(out_lines) + "\n"
710         log_name = ""
711         if expected != actual:
712             log_name = black.dump_to_file(*out_lines)
713         self.assertEqual(
714             expected,
715             actual,
716             f"AST print out is different. Actual version dumped to {log_name}",
717         )
718
719     def test_format_file_contents(self) -> None:
720         empty = ""
721         with self.assertRaises(black.NothingChanged):
722             black.format_file_contents(empty, line_length=ll, fast=False)
723         just_nl = "\n"
724         with self.assertRaises(black.NothingChanged):
725             black.format_file_contents(just_nl, line_length=ll, fast=False)
726         same = "l = [1, 2, 3]\n"
727         with self.assertRaises(black.NothingChanged):
728             black.format_file_contents(same, line_length=ll, fast=False)
729         different = "l = [1,2,3]"
730         expected = same
731         actual = black.format_file_contents(different, line_length=ll, fast=False)
732         self.assertEqual(expected, actual)
733         invalid = "return if you can"
734         with self.assertRaises(ValueError) as e:
735             black.format_file_contents(invalid, line_length=ll, fast=False)
736         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
737
738     def test_endmarker(self) -> None:
739         n = black.lib2to3_parse("\n")
740         self.assertEqual(n.type, black.syms.file_input)
741         self.assertEqual(len(n.children), 1)
742         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
743
744     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
745     def test_assertFormatEqual(self) -> None:
746         out_lines = []
747         err_lines = []
748
749         def out(msg: str, **kwargs: Any) -> None:
750             out_lines.append(msg)
751
752         def err(msg: str, **kwargs: Any) -> None:
753             err_lines.append(msg)
754
755         with patch("black.out", out), patch("black.err", err):
756             with self.assertRaises(AssertionError):
757                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
758
759         out_str = "".join(out_lines)
760         self.assertTrue("Expected tree:" in out_str)
761         self.assertTrue("Actual tree:" in out_str)
762         self.assertEqual("".join(err_lines), "")
763
764     def test_cache_broken_file(self) -> None:
765         mode = black.FileMode.AUTO_DETECT
766         with cache_dir() as workspace:
767             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
768             with cache_file.open("w") as fobj:
769                 fobj.write("this is not a pickle")
770             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
771             src = (workspace / "test.py").resolve()
772             with src.open("w") as fobj:
773                 fobj.write("print('hello')")
774             result = CliRunner().invoke(black.main, [str(src)])
775             self.assertEqual(result.exit_code, 0)
776             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
777             self.assertIn(src, cache)
778
779     def test_cache_single_file_already_cached(self) -> None:
780         mode = black.FileMode.AUTO_DETECT
781         with cache_dir() as workspace:
782             src = (workspace / "test.py").resolve()
783             with src.open("w") as fobj:
784                 fobj.write("print('hello')")
785             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
786             result = CliRunner().invoke(black.main, [str(src)])
787             self.assertEqual(result.exit_code, 0)
788             with src.open("r") as fobj:
789                 self.assertEqual(fobj.read(), "print('hello')")
790
791     @event_loop(close=False)
792     def test_cache_multiple_files(self) -> None:
793         mode = black.FileMode.AUTO_DETECT
794         with cache_dir() as workspace, patch(
795             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
796         ):
797             one = (workspace / "one.py").resolve()
798             with one.open("w") as fobj:
799                 fobj.write("print('hello')")
800             two = (workspace / "two.py").resolve()
801             with two.open("w") as fobj:
802                 fobj.write("print('hello')")
803             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
804             result = CliRunner().invoke(black.main, [str(workspace)])
805             self.assertEqual(result.exit_code, 0)
806             with one.open("r") as fobj:
807                 self.assertEqual(fobj.read(), "print('hello')")
808             with two.open("r") as fobj:
809                 self.assertEqual(fobj.read(), 'print("hello")\n')
810             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
811             self.assertIn(one, cache)
812             self.assertIn(two, cache)
813
814     def test_no_cache_when_writeback_diff(self) -> None:
815         mode = black.FileMode.AUTO_DETECT
816         with cache_dir() as workspace:
817             src = (workspace / "test.py").resolve()
818             with src.open("w") as fobj:
819                 fobj.write("print('hello')")
820             result = CliRunner().invoke(black.main, [str(src), "--diff"])
821             self.assertEqual(result.exit_code, 0)
822             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
823             self.assertFalse(cache_file.exists())
824
825     def test_no_cache_when_stdin(self) -> None:
826         mode = black.FileMode.AUTO_DETECT
827         with cache_dir():
828             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
829             self.assertEqual(result.exit_code, 0)
830             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
831             self.assertFalse(cache_file.exists())
832
833     def test_read_cache_no_cachefile(self) -> None:
834         mode = black.FileMode.AUTO_DETECT
835         with cache_dir():
836             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
837
838     def test_write_cache_read_cache(self) -> None:
839         mode = black.FileMode.AUTO_DETECT
840         with cache_dir() as workspace:
841             src = (workspace / "test.py").resolve()
842             src.touch()
843             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
844             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
845             self.assertIn(src, cache)
846             self.assertEqual(cache[src], black.get_cache_info(src))
847
848     def test_filter_cached(self) -> None:
849         with TemporaryDirectory() as workspace:
850             path = Path(workspace)
851             uncached = (path / "uncached").resolve()
852             cached = (path / "cached").resolve()
853             cached_but_changed = (path / "changed").resolve()
854             uncached.touch()
855             cached.touch()
856             cached_but_changed.touch()
857             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
858             todo, done = black.filter_cached(
859                 cache, [uncached, cached, cached_but_changed]
860             )
861             self.assertEqual(todo, [uncached, cached_but_changed])
862             self.assertEqual(done, [cached])
863
864     def test_write_cache_creates_directory_if_needed(self) -> None:
865         mode = black.FileMode.AUTO_DETECT
866         with cache_dir(exists=False) as workspace:
867             self.assertFalse(workspace.exists())
868             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
869             self.assertTrue(workspace.exists())
870
871     @event_loop(close=False)
872     def test_failed_formatting_does_not_get_cached(self) -> None:
873         mode = black.FileMode.AUTO_DETECT
874         with cache_dir() as workspace, patch(
875             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
876         ):
877             failing = (workspace / "failing.py").resolve()
878             with failing.open("w") as fobj:
879                 fobj.write("not actually python")
880             clean = (workspace / "clean.py").resolve()
881             with clean.open("w") as fobj:
882                 fobj.write('print("hello")\n')
883             result = CliRunner().invoke(black.main, [str(workspace)])
884             self.assertEqual(result.exit_code, 123)
885             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
886             self.assertNotIn(failing, cache)
887             self.assertIn(clean, cache)
888
889     def test_write_cache_write_fail(self) -> None:
890         mode = black.FileMode.AUTO_DETECT
891         with cache_dir(), patch.object(Path, "open") as mock:
892             mock.side_effect = OSError
893             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
894
895     @event_loop(close=False)
896     def test_check_diff_use_together(self) -> None:
897         with cache_dir():
898             # Files which will be reformatted.
899             src1 = (THIS_DIR / "string_quotes.py").resolve()
900             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
901             self.assertEqual(result.exit_code, 1)
902
903             # Files which will not be reformatted.
904             src2 = (THIS_DIR / "composition.py").resolve()
905             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
906             self.assertEqual(result.exit_code, 0)
907
908             # Multi file command.
909             result = CliRunner().invoke(
910                 black.main, [str(src1), str(src2), "--diff", "--check"]
911             )
912             self.assertEqual(result.exit_code, 1, result.output)
913
914     def test_no_files(self) -> None:
915         with cache_dir():
916             # Without an argument, black exits with error code 0.
917             result = CliRunner().invoke(black.main, [])
918             self.assertEqual(result.exit_code, 0)
919
920     def test_broken_symlink(self) -> None:
921         with cache_dir() as workspace:
922             symlink = workspace / "broken_link.py"
923             try:
924                 symlink.symlink_to("nonexistent.py")
925             except OSError as e:
926                 self.skipTest(f"Can't create symlinks: {e}")
927             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
928             self.assertEqual(result.exit_code, 0)
929
930     def test_read_cache_line_lengths(self) -> None:
931         mode = black.FileMode.AUTO_DETECT
932         with cache_dir() as workspace:
933             path = (workspace / "file.py").resolve()
934             path.touch()
935             black.write_cache({}, [path], 1, mode)
936             one = black.read_cache(1, mode)
937             self.assertIn(path, one)
938             two = black.read_cache(2, mode)
939             self.assertNotIn(path, two)
940
941     def test_single_file_force_pyi(self) -> None:
942         reg_mode = black.FileMode.AUTO_DETECT
943         pyi_mode = black.FileMode.PYI
944         contents, expected = read_data("force_pyi")
945         with cache_dir() as workspace:
946             path = (workspace / "file.py").resolve()
947             with open(path, "w") as fh:
948                 fh.write(contents)
949             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
950             self.assertEqual(result.exit_code, 0)
951             with open(path, "r") as fh:
952                 actual = fh.read()
953             # verify cache with --pyi is separate
954             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
955             self.assertIn(path, pyi_cache)
956             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
957             self.assertNotIn(path, normal_cache)
958         self.assertEqual(actual, expected)
959
960     @event_loop(close=False)
961     def test_multi_file_force_pyi(self) -> None:
962         reg_mode = black.FileMode.AUTO_DETECT
963         pyi_mode = black.FileMode.PYI
964         contents, expected = read_data("force_pyi")
965         with cache_dir() as workspace:
966             paths = [
967                 (workspace / "file1.py").resolve(),
968                 (workspace / "file2.py").resolve(),
969             ]
970             for path in paths:
971                 with open(path, "w") as fh:
972                     fh.write(contents)
973             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
974             self.assertEqual(result.exit_code, 0)
975             for path in paths:
976                 with open(path, "r") as fh:
977                     actual = fh.read()
978                 self.assertEqual(actual, expected)
979             # verify cache with --pyi is separate
980             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
981             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
982             for path in paths:
983                 self.assertIn(path, pyi_cache)
984                 self.assertNotIn(path, normal_cache)
985
986     def test_pipe_force_pyi(self) -> None:
987         source, expected = read_data("force_pyi")
988         result = CliRunner().invoke(black.main, ["-", "-q", "--pyi"], input=source)
989         self.assertEqual(result.exit_code, 0)
990         actual = result.output
991         self.assertFormatEqual(actual, expected)
992
993     def test_single_file_force_py36(self) -> None:
994         reg_mode = black.FileMode.AUTO_DETECT
995         py36_mode = black.FileMode.PYTHON36
996         source, expected = read_data("force_py36")
997         with cache_dir() as workspace:
998             path = (workspace / "file.py").resolve()
999             with open(path, "w") as fh:
1000                 fh.write(source)
1001             result = CliRunner().invoke(black.main, [str(path), "--py36"])
1002             self.assertEqual(result.exit_code, 0)
1003             with open(path, "r") as fh:
1004                 actual = fh.read()
1005             # verify cache with --py36 is separate
1006             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1007             self.assertIn(path, py36_cache)
1008             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1009             self.assertNotIn(path, normal_cache)
1010         self.assertEqual(actual, expected)
1011
1012     @event_loop(close=False)
1013     def test_multi_file_force_py36(self) -> None:
1014         reg_mode = black.FileMode.AUTO_DETECT
1015         py36_mode = black.FileMode.PYTHON36
1016         source, expected = read_data("force_py36")
1017         with cache_dir() as workspace:
1018             paths = [
1019                 (workspace / "file1.py").resolve(),
1020                 (workspace / "file2.py").resolve(),
1021             ]
1022             for path in paths:
1023                 with open(path, "w") as fh:
1024                     fh.write(source)
1025             result = CliRunner().invoke(
1026                 black.main, [str(p) for p in paths] + ["--py36"]
1027             )
1028             self.assertEqual(result.exit_code, 0)
1029             for path in paths:
1030                 with open(path, "r") as fh:
1031                     actual = fh.read()
1032                 self.assertEqual(actual, expected)
1033             # verify cache with --py36 is separate
1034             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1035             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1036             for path in paths:
1037                 self.assertIn(path, pyi_cache)
1038                 self.assertNotIn(path, normal_cache)
1039
1040     def test_pipe_force_py36(self) -> None:
1041         source, expected = read_data("force_py36")
1042         result = CliRunner().invoke(black.main, ["-", "-q", "--py36"], input=source)
1043         self.assertEqual(result.exit_code, 0)
1044         actual = result.output
1045         self.assertFormatEqual(actual, expected)
1046
1047     def test_include_exclude(self) -> None:
1048         path = THIS_DIR / "include_exclude_tests"
1049         include = re.compile(r"\.pyi?$")
1050         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1051         report = black.Report()
1052         sources: List[Path] = []
1053         expected = [
1054             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.py"),
1055             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.pyi"),
1056         ]
1057         this_abs = THIS_DIR.resolve()
1058         sources.extend(
1059             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1060         )
1061         self.assertEqual(sorted(expected), sorted(sources))
1062
1063     def test_empty_include(self) -> None:
1064         path = THIS_DIR / "include_exclude_tests"
1065         report = black.Report()
1066         empty = re.compile(r"")
1067         sources: List[Path] = []
1068         expected = [
1069             Path(path / "b/exclude/a.pie"),
1070             Path(path / "b/exclude/a.py"),
1071             Path(path / "b/exclude/a.pyi"),
1072             Path(path / "b/dont_exclude/a.pie"),
1073             Path(path / "b/dont_exclude/a.py"),
1074             Path(path / "b/dont_exclude/a.pyi"),
1075             Path(path / "b/.definitely_exclude/a.pie"),
1076             Path(path / "b/.definitely_exclude/a.py"),
1077             Path(path / "b/.definitely_exclude/a.pyi"),
1078         ]
1079         this_abs = THIS_DIR.resolve()
1080         sources.extend(
1081             black.gen_python_files_in_dir(
1082                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1083             )
1084         )
1085         self.assertEqual(sorted(expected), sorted(sources))
1086
1087     def test_empty_exclude(self) -> None:
1088         path = THIS_DIR / "include_exclude_tests"
1089         report = black.Report()
1090         empty = re.compile(r"")
1091         sources: List[Path] = []
1092         expected = [
1093             Path(path / "b/dont_exclude/a.py"),
1094             Path(path / "b/dont_exclude/a.pyi"),
1095             Path(path / "b/exclude/a.py"),
1096             Path(path / "b/exclude/a.pyi"),
1097             Path(path / "b/.definitely_exclude/a.py"),
1098             Path(path / "b/.definitely_exclude/a.pyi"),
1099         ]
1100         this_abs = THIS_DIR.resolve()
1101         sources.extend(
1102             black.gen_python_files_in_dir(
1103                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1104             )
1105         )
1106         self.assertEqual(sorted(expected), sorted(sources))
1107
1108     def test_invalid_include_exclude(self) -> None:
1109         for option in ["--include", "--exclude"]:
1110             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
1111             self.assertEqual(result.exit_code, 2)
1112
1113     def test_preserves_line_endings(self) -> None:
1114         with TemporaryDirectory() as workspace:
1115             test_file = Path(workspace) / "test.py"
1116             for nl in ["\n", "\r\n"]:
1117                 contents = nl.join(["def f(  ):", "    pass"])
1118                 test_file.write_bytes(contents.encode())
1119                 ff(test_file, write_back=black.WriteBack.YES)
1120                 updated_contents: bytes = test_file.read_bytes()
1121                 self.assertIn(nl.encode(), updated_contents)  # type: ignore
1122                 if nl == "\n":
1123                     self.assertNotIn(b"\r\n", updated_contents)  # type: ignore
1124
1125
if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it.
    unittest.main(module="__main__")