git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add `--verbose` and report excluded paths in it, too
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import StringIO
7 import os
8 from pathlib import Path
9 import sys
10 from tempfile import TemporaryDirectory
11 from typing import Any, List, Tuple, Iterator
12 import unittest
13 from unittest.mock import patch
14 import re
15
16 from click import unstyle
17 from click.testing import CliRunner
18
19 import black
20
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Line length shared by every formatting helper in this module.
ll = 88
# Shortcuts: format a file on disk in place, or format a string in memory.
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
fs = partial(black.format_str, line_length=ll)
# read_data() deletes every occurrence of this marker from fixture lines.  The
# literal is presumably split in two so that this very file, which test_self
# loads through read_data(), does not itself contain the marker.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
27
28
def dump_to_stderr(*output: str) -> str:
    """Stand-in for ``black.dump_to_file`` installed via ``patch`` in these tests.

    Nothing is written anywhere: the pieces are joined with newlines and returned
    with one extra newline before and after.  Presumably this lets the content show
    up inline where black would otherwise report a dump-file path.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
31
32
def read_data(name: str) -> Tuple[str, str]:
    """Load a test fixture and split it into ``(input, expected_output)``.

    ``name`` is resolved relative to ``THIS_DIR``.  ``.py`` is appended unless the
    name already ends in ``.py``, ``.pyi``, ``.out`` or ``.diff``.  Every
    ``EMPTY_LINE`` marker is removed from each line.  A line equal to ``# output``
    (ignoring trailing whitespace) switches collection to the output section and
    is itself dropped.  If the output section is empty while the input is not, the
    input doubles as the expected output.  Both strings are stripped and end with
    exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()
    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0  # 0 = input section, 1 = output section
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            current = 1
            continue
        sections[current].append(cleaned)
    in_lines, out_lines = sections
    if in_lines and not out_lines:
        out_lines = list(in_lines)
    return "".join(in_lines).strip() + "\n", "".join(out_lines).strip() + "\n"
53
54
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point ``black.CACHE_DIR`` at a throwaway directory for the ``with`` body.

    With ``exists=True`` the yielded path is a fresh temporary directory.  With
    ``exists=False`` it is a ``new`` child of that directory, which this helper
    does not create.  Everything is removed when the block exits.
    """
    with TemporaryDirectory() as workspace:
        root = Path(workspace)
        target = root if exists else root / "new"
        with patch("black.CACHE_DIR", target):
            yield target
63
64
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Make a brand-new event loop current for the duration of the block.

    On exit the loop that was current on entry is reinstated.  The temporary loop
    is closed only when ``close`` is true; otherwise it is left open.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield
    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
78
79
80 class BlackTestCase(unittest.TestCase):
81     maxDiff = None
82
83     def assertFormatEqual(self, expected: str, actual: str) -> None:
84         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
85             bdv: black.DebugVisitor[Any]
86             black.out("Expected tree:", fg="green")
87             try:
88                 exp_node = black.lib2to3_parse(expected)
89                 bdv = black.DebugVisitor()
90                 list(bdv.visit(exp_node))
91             except Exception as ve:
92                 black.err(str(ve))
93             black.out("Actual tree:", fg="red")
94             try:
95                 exp_node = black.lib2to3_parse(actual)
96                 bdv = black.DebugVisitor()
97                 list(bdv.visit(exp_node))
98             except Exception as ve:
99                 black.err(str(ve))
100         self.assertEqual(expected, actual)
101
102     @patch("black.dump_to_file", dump_to_stderr)
103     def test_self(self) -> None:
104         source, expected = read_data("test_black")
105         actual = fs(source)
106         self.assertFormatEqual(expected, actual)
107         black.assert_equivalent(source, actual)
108         black.assert_stable(source, actual, line_length=ll)
109         self.assertFalse(ff(THIS_FILE))
110
111     @patch("black.dump_to_file", dump_to_stderr)
112     def test_black(self) -> None:
113         source, expected = read_data("../black")
114         actual = fs(source)
115         self.assertFormatEqual(expected, actual)
116         black.assert_equivalent(source, actual)
117         black.assert_stable(source, actual, line_length=ll)
118         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
119
120     def test_piping(self) -> None:
121         source, expected = read_data("../black")
122         hold_stdin, hold_stdout = sys.stdin, sys.stdout
123         try:
124             sys.stdin, sys.stdout = StringIO(source), StringIO()
125             sys.stdin.name = "<stdin>"
126             black.format_stdin_to_stdout(
127                 line_length=ll, fast=True, write_back=black.WriteBack.YES
128             )
129             sys.stdout.seek(0)
130             actual = sys.stdout.read()
131         finally:
132             sys.stdin, sys.stdout = hold_stdin, hold_stdout
133         self.assertFormatEqual(expected, actual)
134         black.assert_equivalent(source, actual)
135         black.assert_stable(source, actual, line_length=ll)
136
137     def test_piping_diff(self) -> None:
138         source, _ = read_data("expression.py")
139         expected, _ = read_data("expression.diff")
140         hold_stdin, hold_stdout = sys.stdin, sys.stdout
141         try:
142             sys.stdin, sys.stdout = StringIO(source), StringIO()
143             sys.stdin.name = "<stdin>"
144             black.format_stdin_to_stdout(
145                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
146             )
147             sys.stdout.seek(0)
148             actual = sys.stdout.read()
149         finally:
150             sys.stdin, sys.stdout = hold_stdin, hold_stdout
151         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
152         self.assertEqual(expected, actual)
153
154     @patch("black.dump_to_file", dump_to_stderr)
155     def test_setup(self) -> None:
156         source, expected = read_data("../setup")
157         actual = fs(source)
158         self.assertFormatEqual(expected, actual)
159         black.assert_equivalent(source, actual)
160         black.assert_stable(source, actual, line_length=ll)
161         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
162
163     @patch("black.dump_to_file", dump_to_stderr)
164     def test_function(self) -> None:
165         source, expected = read_data("function")
166         actual = fs(source)
167         self.assertFormatEqual(expected, actual)
168         black.assert_equivalent(source, actual)
169         black.assert_stable(source, actual, line_length=ll)
170
171     @patch("black.dump_to_file", dump_to_stderr)
172     def test_function2(self) -> None:
173         source, expected = read_data("function2")
174         actual = fs(source)
175         self.assertFormatEqual(expected, actual)
176         black.assert_equivalent(source, actual)
177         black.assert_stable(source, actual, line_length=ll)
178
179     @patch("black.dump_to_file", dump_to_stderr)
180     def test_expression(self) -> None:
181         source, expected = read_data("expression")
182         actual = fs(source)
183         self.assertFormatEqual(expected, actual)
184         black.assert_equivalent(source, actual)
185         black.assert_stable(source, actual, line_length=ll)
186
187     def test_expression_ff(self) -> None:
188         source, expected = read_data("expression")
189         tmp_file = Path(black.dump_to_file(source))
190         try:
191             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
192             with open(tmp_file, encoding="utf8") as f:
193                 actual = f.read()
194         finally:
195             os.unlink(tmp_file)
196         self.assertFormatEqual(expected, actual)
197         with patch("black.dump_to_file", dump_to_stderr):
198             black.assert_equivalent(source, actual)
199             black.assert_stable(source, actual, line_length=ll)
200
201     def test_expression_diff(self) -> None:
202         source, _ = read_data("expression.py")
203         expected, _ = read_data("expression.diff")
204         tmp_file = Path(black.dump_to_file(source))
205         hold_stdout = sys.stdout
206         try:
207             sys.stdout = StringIO()
208             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
209             sys.stdout.seek(0)
210             actual = sys.stdout.read()
211             actual = actual.replace(str(tmp_file), "<stdin>")
212         finally:
213             sys.stdout = hold_stdout
214             os.unlink(tmp_file)
215         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
216         if expected != actual:
217             dump = black.dump_to_file(actual)
218             msg = (
219                 f"Expected diff isn't equal to the actual. If you made changes "
220                 f"to expression.py and this is an anticipated difference, "
221                 f"overwrite tests/expression.diff with {dump}"
222             )
223             self.assertEqual(expected, actual, msg)
224
225     @patch("black.dump_to_file", dump_to_stderr)
226     def test_fstring(self) -> None:
227         source, expected = read_data("fstring")
228         actual = fs(source)
229         self.assertFormatEqual(expected, actual)
230         black.assert_equivalent(source, actual)
231         black.assert_stable(source, actual, line_length=ll)
232
233     @patch("black.dump_to_file", dump_to_stderr)
234     def test_string_quotes(self) -> None:
235         source, expected = read_data("string_quotes")
236         actual = fs(source)
237         self.assertFormatEqual(expected, actual)
238         black.assert_equivalent(source, actual)
239         black.assert_stable(source, actual, line_length=ll)
240         mode = black.FileMode.NO_STRING_NORMALIZATION
241         not_normalized = fs(source, mode=mode)
242         self.assertFormatEqual(source, not_normalized)
243         black.assert_equivalent(source, not_normalized)
244         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
245
246     @patch("black.dump_to_file", dump_to_stderr)
247     def test_slices(self) -> None:
248         source, expected = read_data("slices")
249         actual = fs(source)
250         self.assertFormatEqual(expected, actual)
251         black.assert_equivalent(source, actual)
252         black.assert_stable(source, actual, line_length=ll)
253
254     @patch("black.dump_to_file", dump_to_stderr)
255     def test_comments(self) -> None:
256         source, expected = read_data("comments")
257         actual = fs(source)
258         self.assertFormatEqual(expected, actual)
259         black.assert_equivalent(source, actual)
260         black.assert_stable(source, actual, line_length=ll)
261
262     @patch("black.dump_to_file", dump_to_stderr)
263     def test_comments2(self) -> None:
264         source, expected = read_data("comments2")
265         actual = fs(source)
266         self.assertFormatEqual(expected, actual)
267         black.assert_equivalent(source, actual)
268         black.assert_stable(source, actual, line_length=ll)
269
270     @patch("black.dump_to_file", dump_to_stderr)
271     def test_comments3(self) -> None:
272         source, expected = read_data("comments3")
273         actual = fs(source)
274         self.assertFormatEqual(expected, actual)
275         black.assert_equivalent(source, actual)
276         black.assert_stable(source, actual, line_length=ll)
277
278     @patch("black.dump_to_file", dump_to_stderr)
279     def test_comments4(self) -> None:
280         source, expected = read_data("comments4")
281         actual = fs(source)
282         self.assertFormatEqual(expected, actual)
283         black.assert_equivalent(source, actual)
284         black.assert_stable(source, actual, line_length=ll)
285
286     @patch("black.dump_to_file", dump_to_stderr)
287     def test_comments5(self) -> None:
288         source, expected = read_data("comments5")
289         actual = fs(source)
290         self.assertFormatEqual(expected, actual)
291         black.assert_equivalent(source, actual)
292         black.assert_stable(source, actual, line_length=ll)
293
294     @patch("black.dump_to_file", dump_to_stderr)
295     def test_cantfit(self) -> None:
296         source, expected = read_data("cantfit")
297         actual = fs(source)
298         self.assertFormatEqual(expected, actual)
299         black.assert_equivalent(source, actual)
300         black.assert_stable(source, actual, line_length=ll)
301
302     @patch("black.dump_to_file", dump_to_stderr)
303     def test_import_spacing(self) -> None:
304         source, expected = read_data("import_spacing")
305         actual = fs(source)
306         self.assertFormatEqual(expected, actual)
307         black.assert_equivalent(source, actual)
308         black.assert_stable(source, actual, line_length=ll)
309
310     @patch("black.dump_to_file", dump_to_stderr)
311     def test_composition(self) -> None:
312         source, expected = read_data("composition")
313         actual = fs(source)
314         self.assertFormatEqual(expected, actual)
315         black.assert_equivalent(source, actual)
316         black.assert_stable(source, actual, line_length=ll)
317
318     @patch("black.dump_to_file", dump_to_stderr)
319     def test_empty_lines(self) -> None:
320         source, expected = read_data("empty_lines")
321         actual = fs(source)
322         self.assertFormatEqual(expected, actual)
323         black.assert_equivalent(source, actual)
324         black.assert_stable(source, actual, line_length=ll)
325
326     @patch("black.dump_to_file", dump_to_stderr)
327     def test_string_prefixes(self) -> None:
328         source, expected = read_data("string_prefixes")
329         actual = fs(source)
330         self.assertFormatEqual(expected, actual)
331         black.assert_equivalent(source, actual)
332         black.assert_stable(source, actual, line_length=ll)
333
334     @patch("black.dump_to_file", dump_to_stderr)
335     def test_python2(self) -> None:
336         source, expected = read_data("python2")
337         actual = fs(source)
338         self.assertFormatEqual(expected, actual)
339         # black.assert_equivalent(source, actual)
340         black.assert_stable(source, actual, line_length=ll)
341
342     @patch("black.dump_to_file", dump_to_stderr)
343     def test_python2_unicode_literals(self) -> None:
344         source, expected = read_data("python2_unicode_literals")
345         actual = fs(source)
346         self.assertFormatEqual(expected, actual)
347         black.assert_stable(source, actual, line_length=ll)
348
349     @patch("black.dump_to_file", dump_to_stderr)
350     def test_stub(self) -> None:
351         mode = black.FileMode.PYI
352         source, expected = read_data("stub.pyi")
353         actual = fs(source, mode=mode)
354         self.assertFormatEqual(expected, actual)
355         black.assert_stable(source, actual, line_length=ll, mode=mode)
356
357     @patch("black.dump_to_file", dump_to_stderr)
358     def test_fmtonoff(self) -> None:
359         source, expected = read_data("fmtonoff")
360         actual = fs(source)
361         self.assertFormatEqual(expected, actual)
362         black.assert_equivalent(source, actual)
363         black.assert_stable(source, actual, line_length=ll)
364
365     @patch("black.dump_to_file", dump_to_stderr)
366     def test_remove_empty_parentheses_after_class(self) -> None:
367         source, expected = read_data("class_blank_parentheses")
368         actual = fs(source)
369         self.assertFormatEqual(expected, actual)
370         black.assert_equivalent(source, actual)
371         black.assert_stable(source, actual, line_length=ll)
372
373     @patch("black.dump_to_file", dump_to_stderr)
374     def test_new_line_between_class_and_code(self) -> None:
375         source, expected = read_data("class_methods_new_line")
376         actual = fs(source)
377         self.assertFormatEqual(expected, actual)
378         black.assert_equivalent(source, actual)
379         black.assert_stable(source, actual, line_length=ll)
380
    def test_report_verbose(self) -> None:
        """Check ``black.Report(verbose=True)`` messages, summary and return code.

        ``black.out`` and ``black.err`` are patched to collect messages.  A scripted
        sequence of ``done``/``failed``/``path_ignored`` calls then checks, after
        every step, the number of messages on each stream, the latest message,
        ``str(report)`` and ``report.return_code``.  In this verbose run every
        ``done`` and ``path_ignored`` call is expected to emit an ``out`` message.
        """
        report = black.Report(verbose=True)
        out_lines: List[str] = []  # messages captured from black.out
        err_lines: List[str] = []  # messages captured from black.err

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            # CACHED results are counted as "left unchanged" in the summary.
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # With check enabled, the same tallies (one file reformatted) yield 1.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            # Failures go to black.err and switch the return code to 123.
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # Ignored paths are reported in verbose mode but not counted.
            report.path_ignored(Path("wat"), "no match")
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "wat ignored: no match")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 6)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            # In check mode the summary switches to the conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )
458             self.assertEqual(report.return_code, 123)
459             report.done(Path("f4"), black.Changed.NO)
460             self.assertEqual(len(out_lines), 6)
461             self.assertEqual(len(err_lines), 2)
462             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
463             self.assertEqual(
464                 unstyle(str(report)),
465                 "2 files reformatted, 3 files left unchanged, "
466                 "2 files failed to reformat.",
467             )
468             self.assertEqual(report.return_code, 123)
469             report.check = True
470             self.assertEqual(
471                 unstyle(str(report)),
472                 "2 files would be reformatted, 3 files would be left unchanged, "
473                 "2 files would fail to reformat.",
474             )
475
476     def test_report_quiet(self) -> None:
477         report = black.Report(quiet=True)
478         out_lines = []
479         err_lines = []
480
481         def out(msg: str, **kwargs: Any) -> None:
482             out_lines.append(msg)
483
484         def err(msg: str, **kwargs: Any) -> None:
485             err_lines.append(msg)
486
487         with patch("black.out", out), patch("black.err", err):
488             report.done(Path("f1"), black.Changed.NO)
489             self.assertEqual(len(out_lines), 0)
490             self.assertEqual(len(err_lines), 0)
491             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
492             self.assertEqual(report.return_code, 0)
493             report.done(Path("f2"), black.Changed.YES)
494             self.assertEqual(len(out_lines), 0)
495             self.assertEqual(len(err_lines), 0)
496             self.assertEqual(
497                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
498             )
499             report.done(Path("f3"), black.Changed.CACHED)
500             self.assertEqual(len(out_lines), 0)
501             self.assertEqual(len(err_lines), 0)
502             self.assertEqual(
503                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
504             )
505             self.assertEqual(report.return_code, 0)
506             report.check = True
507             self.assertEqual(report.return_code, 1)
508             report.check = False
509             report.failed(Path("e1"), "boom")
510             self.assertEqual(len(out_lines), 0)
511             self.assertEqual(len(err_lines), 1)
512             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
513             self.assertEqual(
514                 unstyle(str(report)),
515                 "1 file reformatted, 2 files left unchanged, "
516                 "1 file failed to reformat.",
517             )
518             self.assertEqual(report.return_code, 123)
519             report.done(Path("f3"), black.Changed.YES)
520             self.assertEqual(len(out_lines), 0)
521             self.assertEqual(len(err_lines), 1)
522             self.assertEqual(
523                 unstyle(str(report)),
524                 "2 files reformatted, 2 files left unchanged, "
525                 "1 file failed to reformat.",
526             )
527             self.assertEqual(report.return_code, 123)
528             report.failed(Path("e2"), "boom")
529             self.assertEqual(len(out_lines), 0)
530             self.assertEqual(len(err_lines), 2)
531             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
532             self.assertEqual(
533                 unstyle(str(report)),
534                 "2 files reformatted, 2 files left unchanged, "
535                 "2 files failed to reformat.",
536             )
537             self.assertEqual(report.return_code, 123)
538             report.path_ignored(Path("wat"), "no match")
539             self.assertEqual(len(out_lines), 0)
540             self.assertEqual(len(err_lines), 2)
541             self.assertEqual(
542                 unstyle(str(report)),
543                 "2 files reformatted, 2 files left unchanged, "
544                 "2 files failed to reformat.",
545             )
546             self.assertEqual(report.return_code, 123)
547             report.done(Path("f4"), black.Changed.NO)
548             self.assertEqual(len(out_lines), 0)
549             self.assertEqual(len(err_lines), 2)
550             self.assertEqual(
551                 unstyle(str(report)),
552                 "2 files reformatted, 3 files left unchanged, "
553                 "2 files failed to reformat.",
554             )
555             self.assertEqual(report.return_code, 123)
556             report.check = True
557             self.assertEqual(
558                 unstyle(str(report)),
559                 "2 files would be reformatted, 3 files would be left unchanged, "
560                 "2 files would fail to reformat.",
561             )
562
563     def test_report_normal(self) -> None:
564         report = black.Report()
565         out_lines = []
566         err_lines = []
567
568         def out(msg: str, **kwargs: Any) -> None:
569             out_lines.append(msg)
570
571         def err(msg: str, **kwargs: Any) -> None:
572             err_lines.append(msg)
573
574         with patch("black.out", out), patch("black.err", err):
575             report.done(Path("f1"), black.Changed.NO)
576             self.assertEqual(len(out_lines), 0)
577             self.assertEqual(len(err_lines), 0)
578             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
579             self.assertEqual(report.return_code, 0)
580             report.done(Path("f2"), black.Changed.YES)
581             self.assertEqual(len(out_lines), 1)
582             self.assertEqual(len(err_lines), 0)
583             self.assertEqual(out_lines[-1], "reformatted f2")
584             self.assertEqual(
585                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
586             )
587             report.done(Path("f3"), black.Changed.CACHED)
588             self.assertEqual(len(out_lines), 1)
589             self.assertEqual(len(err_lines), 0)
590             self.assertEqual(out_lines[-1], "reformatted f2")
591             self.assertEqual(
592                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
593             )
594             self.assertEqual(report.return_code, 0)
595             report.check = True
596             self.assertEqual(report.return_code, 1)
597             report.check = False
598             report.failed(Path("e1"), "boom")
599             self.assertEqual(len(out_lines), 1)
600             self.assertEqual(len(err_lines), 1)
601             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
602             self.assertEqual(
603                 unstyle(str(report)),
604                 "1 file reformatted, 2 files left unchanged, "
605                 "1 file failed to reformat.",
606             )
607             self.assertEqual(report.return_code, 123)
608             report.done(Path("f3"), black.Changed.YES)
609             self.assertEqual(len(out_lines), 2)
610             self.assertEqual(len(err_lines), 1)
611             self.assertEqual(out_lines[-1], "reformatted f3")
612             self.assertEqual(
613                 unstyle(str(report)),
614                 "2 files reformatted, 2 files left unchanged, "
615                 "1 file failed to reformat.",
616             )
617             self.assertEqual(report.return_code, 123)
618             report.failed(Path("e2"), "boom")
619             self.assertEqual(len(out_lines), 2)
620             self.assertEqual(len(err_lines), 2)
621             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
622             self.assertEqual(
623                 unstyle(str(report)),
624                 "2 files reformatted, 2 files left unchanged, "
625                 "2 files failed to reformat.",
626             )
627             self.assertEqual(report.return_code, 123)
628             report.path_ignored(Path("wat"), "no match")
629             self.assertEqual(len(out_lines), 2)
630             self.assertEqual(len(err_lines), 2)
631             self.assertEqual(
632                 unstyle(str(report)),
633                 "2 files reformatted, 2 files left unchanged, "
634                 "2 files failed to reformat.",
635             )
636             self.assertEqual(report.return_code, 123)
637             report.done(Path("f4"), black.Changed.NO)
638             self.assertEqual(len(out_lines), 2)
639             self.assertEqual(len(err_lines), 2)
640             self.assertEqual(
641                 unstyle(str(report)),
642                 "2 files reformatted, 3 files left unchanged, "
643                 "2 files failed to reformat.",
644             )
645             self.assertEqual(report.return_code, 123)
646             report.check = True
647             self.assertEqual(
648                 unstyle(str(report)),
649                 "2 files would be reformatted, 3 files would be left unchanged, "
650                 "2 files would fail to reformat.",
651             )
652
653     def test_is_python36(self) -> None:
654         node = black.lib2to3_parse("def f(*, arg): ...\n")
655         self.assertFalse(black.is_python36(node))
656         node = black.lib2to3_parse("def f(*, arg,): ...\n")
657         self.assertTrue(black.is_python36(node))
658         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
659         self.assertTrue(black.is_python36(node))
660         source, expected = read_data("function")
661         node = black.lib2to3_parse(source)
662         self.assertTrue(black.is_python36(node))
663         node = black.lib2to3_parse(expected)
664         self.assertTrue(black.is_python36(node))
665         source, expected = read_data("expression")
666         node = black.lib2to3_parse(source)
667         self.assertFalse(black.is_python36(node))
668         node = black.lib2to3_parse(expected)
669         self.assertFalse(black.is_python36(node))
670
671     def test_get_future_imports(self) -> None:
672         node = black.lib2to3_parse("\n")
673         self.assertEqual(set(), black.get_future_imports(node))
674         node = black.lib2to3_parse("from __future__ import black\n")
675         self.assertEqual({"black"}, black.get_future_imports(node))
676         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
677         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
678         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
679         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
680         node = black.lib2to3_parse(
681             "from __future__ import multiple\nfrom __future__ import imports\n"
682         )
683         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
684         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
685         self.assertEqual({"black"}, black.get_future_imports(node))
686         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
687         self.assertEqual({"black"}, black.get_future_imports(node))
688         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
689         self.assertEqual(set(), black.get_future_imports(node))
690         node = black.lib2to3_parse("from some.module import black\n")
691         self.assertEqual(set(), black.get_future_imports(node))
692
693     def test_debug_visitor(self) -> None:
694         source, _ = read_data("debug_visitor.py")
695         expected, _ = read_data("debug_visitor.out")
696         out_lines = []
697         err_lines = []
698
699         def out(msg: str, **kwargs: Any) -> None:
700             out_lines.append(msg)
701
702         def err(msg: str, **kwargs: Any) -> None:
703             err_lines.append(msg)
704
705         with patch("black.out", out), patch("black.err", err):
706             black.DebugVisitor.show(source)
707         actual = "\n".join(out_lines) + "\n"
708         log_name = ""
709         if expected != actual:
710             log_name = black.dump_to_file(*out_lines)
711         self.assertEqual(
712             expected,
713             actual,
714             f"AST print out is different. Actual version dumped to {log_name}",
715         )
716
717     def test_format_file_contents(self) -> None:
718         empty = ""
719         with self.assertRaises(black.NothingChanged):
720             black.format_file_contents(empty, line_length=ll, fast=False)
721         just_nl = "\n"
722         with self.assertRaises(black.NothingChanged):
723             black.format_file_contents(just_nl, line_length=ll, fast=False)
724         same = "l = [1, 2, 3]\n"
725         with self.assertRaises(black.NothingChanged):
726             black.format_file_contents(same, line_length=ll, fast=False)
727         different = "l = [1,2,3]"
728         expected = same
729         actual = black.format_file_contents(different, line_length=ll, fast=False)
730         self.assertEqual(expected, actual)
731         invalid = "return if you can"
732         with self.assertRaises(ValueError) as e:
733             black.format_file_contents(invalid, line_length=ll, fast=False)
734         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
735
736     def test_endmarker(self) -> None:
737         n = black.lib2to3_parse("\n")
738         self.assertEqual(n.type, black.syms.file_input)
739         self.assertEqual(len(n.children), 1)
740         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
741
742     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
743     def test_assertFormatEqual(self) -> None:
744         out_lines = []
745         err_lines = []
746
747         def out(msg: str, **kwargs: Any) -> None:
748             out_lines.append(msg)
749
750         def err(msg: str, **kwargs: Any) -> None:
751             err_lines.append(msg)
752
753         with patch("black.out", out), patch("black.err", err):
754             with self.assertRaises(AssertionError):
755                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
756
757         out_str = "".join(out_lines)
758         self.assertTrue("Expected tree:" in out_str)
759         self.assertTrue("Actual tree:" in out_str)
760         self.assertEqual("".join(err_lines), "")
761
762     def test_cache_broken_file(self) -> None:
763         mode = black.FileMode.AUTO_DETECT
764         with cache_dir() as workspace:
765             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
766             with cache_file.open("w") as fobj:
767                 fobj.write("this is not a pickle")
768             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
769             src = (workspace / "test.py").resolve()
770             with src.open("w") as fobj:
771                 fobj.write("print('hello')")
772             result = CliRunner().invoke(black.main, [str(src)])
773             self.assertEqual(result.exit_code, 0)
774             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
775             self.assertIn(src, cache)
776
777     def test_cache_single_file_already_cached(self) -> None:
778         mode = black.FileMode.AUTO_DETECT
779         with cache_dir() as workspace:
780             src = (workspace / "test.py").resolve()
781             with src.open("w") as fobj:
782                 fobj.write("print('hello')")
783             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
784             result = CliRunner().invoke(black.main, [str(src)])
785             self.assertEqual(result.exit_code, 0)
786             with src.open("r") as fobj:
787                 self.assertEqual(fobj.read(), "print('hello')")
788
789     @event_loop(close=False)
790     def test_cache_multiple_files(self) -> None:
791         mode = black.FileMode.AUTO_DETECT
792         with cache_dir() as workspace, patch(
793             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
794         ):
795             one = (workspace / "one.py").resolve()
796             with one.open("w") as fobj:
797                 fobj.write("print('hello')")
798             two = (workspace / "two.py").resolve()
799             with two.open("w") as fobj:
800                 fobj.write("print('hello')")
801             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
802             result = CliRunner().invoke(black.main, [str(workspace)])
803             self.assertEqual(result.exit_code, 0)
804             with one.open("r") as fobj:
805                 self.assertEqual(fobj.read(), "print('hello')")
806             with two.open("r") as fobj:
807                 self.assertEqual(fobj.read(), 'print("hello")\n')
808             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
809             self.assertIn(one, cache)
810             self.assertIn(two, cache)
811
812     def test_no_cache_when_writeback_diff(self) -> None:
813         mode = black.FileMode.AUTO_DETECT
814         with cache_dir() as workspace:
815             src = (workspace / "test.py").resolve()
816             with src.open("w") as fobj:
817                 fobj.write("print('hello')")
818             result = CliRunner().invoke(black.main, [str(src), "--diff"])
819             self.assertEqual(result.exit_code, 0)
820             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
821             self.assertFalse(cache_file.exists())
822
823     def test_no_cache_when_stdin(self) -> None:
824         mode = black.FileMode.AUTO_DETECT
825         with cache_dir():
826             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
827             self.assertEqual(result.exit_code, 0)
828             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
829             self.assertFalse(cache_file.exists())
830
831     def test_read_cache_no_cachefile(self) -> None:
832         mode = black.FileMode.AUTO_DETECT
833         with cache_dir():
834             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
835
836     def test_write_cache_read_cache(self) -> None:
837         mode = black.FileMode.AUTO_DETECT
838         with cache_dir() as workspace:
839             src = (workspace / "test.py").resolve()
840             src.touch()
841             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
842             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
843             self.assertIn(src, cache)
844             self.assertEqual(cache[src], black.get_cache_info(src))
845
846     def test_filter_cached(self) -> None:
847         with TemporaryDirectory() as workspace:
848             path = Path(workspace)
849             uncached = (path / "uncached").resolve()
850             cached = (path / "cached").resolve()
851             cached_but_changed = (path / "changed").resolve()
852             uncached.touch()
853             cached.touch()
854             cached_but_changed.touch()
855             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
856             todo, done = black.filter_cached(
857                 cache, [uncached, cached, cached_but_changed]
858             )
859             self.assertEqual(todo, [uncached, cached_but_changed])
860             self.assertEqual(done, [cached])
861
862     def test_write_cache_creates_directory_if_needed(self) -> None:
863         mode = black.FileMode.AUTO_DETECT
864         with cache_dir(exists=False) as workspace:
865             self.assertFalse(workspace.exists())
866             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
867             self.assertTrue(workspace.exists())
868
869     @event_loop(close=False)
870     def test_failed_formatting_does_not_get_cached(self) -> None:
871         mode = black.FileMode.AUTO_DETECT
872         with cache_dir() as workspace, patch(
873             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
874         ):
875             failing = (workspace / "failing.py").resolve()
876             with failing.open("w") as fobj:
877                 fobj.write("not actually python")
878             clean = (workspace / "clean.py").resolve()
879             with clean.open("w") as fobj:
880                 fobj.write('print("hello")\n')
881             result = CliRunner().invoke(black.main, [str(workspace)])
882             self.assertEqual(result.exit_code, 123)
883             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
884             self.assertNotIn(failing, cache)
885             self.assertIn(clean, cache)
886
887     def test_write_cache_write_fail(self) -> None:
888         mode = black.FileMode.AUTO_DETECT
889         with cache_dir(), patch.object(Path, "open") as mock:
890             mock.side_effect = OSError
891             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
892
893     @event_loop(close=False)
894     def test_check_diff_use_together(self) -> None:
895         with cache_dir():
896             # Files which will be reformatted.
897             src1 = (THIS_DIR / "string_quotes.py").resolve()
898             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
899             self.assertEqual(result.exit_code, 1)
900
901             # Files which will not be reformatted.
902             src2 = (THIS_DIR / "composition.py").resolve()
903             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
904             self.assertEqual(result.exit_code, 0)
905
906             # Multi file command.
907             result = CliRunner().invoke(
908                 black.main, [str(src1), str(src2), "--diff", "--check"]
909             )
910             self.assertEqual(result.exit_code, 1, result.output)
911
912     def test_no_files(self) -> None:
913         with cache_dir():
914             # Without an argument, black exits with error code 0.
915             result = CliRunner().invoke(black.main, [])
916             self.assertEqual(result.exit_code, 0)
917
918     def test_broken_symlink(self) -> None:
919         with cache_dir() as workspace:
920             symlink = workspace / "broken_link.py"
921             try:
922                 symlink.symlink_to("nonexistent.py")
923             except OSError as e:
924                 self.skipTest(f"Can't create symlinks: {e}")
925             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
926             self.assertEqual(result.exit_code, 0)
927
928     def test_read_cache_line_lengths(self) -> None:
929         mode = black.FileMode.AUTO_DETECT
930         with cache_dir() as workspace:
931             path = (workspace / "file.py").resolve()
932             path.touch()
933             black.write_cache({}, [path], 1, mode)
934             one = black.read_cache(1, mode)
935             self.assertIn(path, one)
936             two = black.read_cache(2, mode)
937             self.assertNotIn(path, two)
938
939     def test_single_file_force_pyi(self) -> None:
940         reg_mode = black.FileMode.AUTO_DETECT
941         pyi_mode = black.FileMode.PYI
942         contents, expected = read_data("force_pyi")
943         with cache_dir() as workspace:
944             path = (workspace / "file.py").resolve()
945             with open(path, "w") as fh:
946                 fh.write(contents)
947             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
948             self.assertEqual(result.exit_code, 0)
949             with open(path, "r") as fh:
950                 actual = fh.read()
951             # verify cache with --pyi is separate
952             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
953             self.assertIn(path, pyi_cache)
954             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
955             self.assertNotIn(path, normal_cache)
956         self.assertEqual(actual, expected)
957
958     @event_loop(close=False)
959     def test_multi_file_force_pyi(self) -> None:
960         reg_mode = black.FileMode.AUTO_DETECT
961         pyi_mode = black.FileMode.PYI
962         contents, expected = read_data("force_pyi")
963         with cache_dir() as workspace:
964             paths = [
965                 (workspace / "file1.py").resolve(),
966                 (workspace / "file2.py").resolve(),
967             ]
968             for path in paths:
969                 with open(path, "w") as fh:
970                     fh.write(contents)
971             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
972             self.assertEqual(result.exit_code, 0)
973             for path in paths:
974                 with open(path, "r") as fh:
975                     actual = fh.read()
976                 self.assertEqual(actual, expected)
977             # verify cache with --pyi is separate
978             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
979             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
980             for path in paths:
981                 self.assertIn(path, pyi_cache)
982                 self.assertNotIn(path, normal_cache)
983
984     def test_pipe_force_pyi(self) -> None:
985         source, expected = read_data("force_pyi")
986         result = CliRunner().invoke(black.main, ["-", "-q", "--pyi"], input=source)
987         self.assertEqual(result.exit_code, 0)
988         actual = result.output
989         self.assertFormatEqual(actual, expected)
990
991     def test_single_file_force_py36(self) -> None:
992         reg_mode = black.FileMode.AUTO_DETECT
993         py36_mode = black.FileMode.PYTHON36
994         source, expected = read_data("force_py36")
995         with cache_dir() as workspace:
996             path = (workspace / "file.py").resolve()
997             with open(path, "w") as fh:
998                 fh.write(source)
999             result = CliRunner().invoke(black.main, [str(path), "--py36"])
1000             self.assertEqual(result.exit_code, 0)
1001             with open(path, "r") as fh:
1002                 actual = fh.read()
1003             # verify cache with --py36 is separate
1004             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1005             self.assertIn(path, py36_cache)
1006             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1007             self.assertNotIn(path, normal_cache)
1008         self.assertEqual(actual, expected)
1009
1010     @event_loop(close=False)
1011     def test_multi_file_force_py36(self) -> None:
1012         reg_mode = black.FileMode.AUTO_DETECT
1013         py36_mode = black.FileMode.PYTHON36
1014         source, expected = read_data("force_py36")
1015         with cache_dir() as workspace:
1016             paths = [
1017                 (workspace / "file1.py").resolve(),
1018                 (workspace / "file2.py").resolve(),
1019             ]
1020             for path in paths:
1021                 with open(path, "w") as fh:
1022                     fh.write(source)
1023             result = CliRunner().invoke(
1024                 black.main, [str(p) for p in paths] + ["--py36"]
1025             )
1026             self.assertEqual(result.exit_code, 0)
1027             for path in paths:
1028                 with open(path, "r") as fh:
1029                     actual = fh.read()
1030                 self.assertEqual(actual, expected)
1031             # verify cache with --py36 is separate
1032             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1033             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1034             for path in paths:
1035                 self.assertIn(path, pyi_cache)
1036                 self.assertNotIn(path, normal_cache)
1037
1038     def test_pipe_force_py36(self) -> None:
1039         source, expected = read_data("force_py36")
1040         result = CliRunner().invoke(black.main, ["-", "-q", "--py36"], input=source)
1041         self.assertEqual(result.exit_code, 0)
1042         actual = result.output
1043         self.assertFormatEqual(actual, expected)
1044
1045     def test_include_exclude(self) -> None:
1046         path = THIS_DIR / "include_exclude_tests"
1047         include = re.compile(r"\.pyi?$")
1048         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1049         report = black.Report()
1050         sources: List[Path] = []
1051         expected = [
1052             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.py"),
1053             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.pyi"),
1054         ]
1055         this_abs = THIS_DIR.resolve()
1056         sources.extend(
1057             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1058         )
1059         self.assertEqual(sorted(expected), sorted(sources))
1060
1061     def test_empty_include(self) -> None:
1062         path = THIS_DIR / "include_exclude_tests"
1063         report = black.Report()
1064         empty = re.compile(r"")
1065         sources: List[Path] = []
1066         expected = [
1067             Path(path / "b/exclude/a.pie"),
1068             Path(path / "b/exclude/a.py"),
1069             Path(path / "b/exclude/a.pyi"),
1070             Path(path / "b/dont_exclude/a.pie"),
1071             Path(path / "b/dont_exclude/a.py"),
1072             Path(path / "b/dont_exclude/a.pyi"),
1073             Path(path / "b/.definitely_exclude/a.pie"),
1074             Path(path / "b/.definitely_exclude/a.py"),
1075             Path(path / "b/.definitely_exclude/a.pyi"),
1076         ]
1077         this_abs = THIS_DIR.resolve()
1078         sources.extend(
1079             black.gen_python_files_in_dir(
1080                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1081             )
1082         )
1083         self.assertEqual(sorted(expected), sorted(sources))
1084
1085     def test_empty_exclude(self) -> None:
1086         path = THIS_DIR / "include_exclude_tests"
1087         report = black.Report()
1088         empty = re.compile(r"")
1089         sources: List[Path] = []
1090         expected = [
1091             Path(path / "b/dont_exclude/a.py"),
1092             Path(path / "b/dont_exclude/a.pyi"),
1093             Path(path / "b/exclude/a.py"),
1094             Path(path / "b/exclude/a.pyi"),
1095             Path(path / "b/.definitely_exclude/a.py"),
1096             Path(path / "b/.definitely_exclude/a.pyi"),
1097         ]
1098         this_abs = THIS_DIR.resolve()
1099         sources.extend(
1100             black.gen_python_files_in_dir(
1101                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1102             )
1103         )
1104         self.assertEqual(sorted(expected), sorted(sources))
1105
1106     def test_invalid_include_exclude(self) -> None:
1107         for option in ["--include", "--exclude"]:
1108             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
1109             self.assertEqual(result.exit_code, 2)
1110
1111
# Run this module's tests via unittest when it is executed as a script.
if __name__ == "__main__":
    unittest.main()