git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

f6952c721ccd0ba64f8c921e3dd10d2a75c1d6c1
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from concurrent.futures import ThreadPoolExecutor
4 from contextlib import contextmanager
5 from functools import partial
6 from io import BytesIO, TextIOWrapper
7 import os
8 from pathlib import Path
9 import re
10 import sys
11 from tempfile import TemporaryDirectory
12 from typing import Any, List, Tuple, Iterator
13 import unittest
14 from unittest.mock import patch
15
16 from click import unstyle
17 from click.testing import CliRunner
18
19 import black
20
21
ll = 88  # line length passed to the formatting helpers below and to assert_stable
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# read_data() strips this marker from fixture lines. The literal is built from
# two pieces so that this module (read as a fixture by test_self) does not
# contain the marker verbatim.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
28
29
def dump_to_stderr(*output: str) -> str:
    """Stand-in for black.dump_to_file used while patching in the tests.

    Joins the given chunks with newlines and returns the text framed by a
    leading and a trailing newline; it does not write anything itself.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
32
33
def read_data(name: str) -> Tuple[str, str]:
    """Load the fixture ``name`` from THIS_DIR and return (input, expected output).

    A ".py" suffix is appended unless the name already ends in ".py", ".pyi",
    ".out" or ".diff". Every line first has EMPTY_LINE removed. Lines before
    a line reading "# output" form the input; lines after it form the
    expected output. When no output lines were collected but there is input,
    the input doubles as the expected output. Both results are stripped and
    end with exactly one newline.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as fixture:
        raw_lines = fixture.readlines()
    sections: Tuple[List[str], List[str]] = ([], [])
    current = 0
    for raw in raw_lines:
        text = raw.replace(EMPTY_LINE, "")
        if text.rstrip() == "# output":
            current = 1
            continue

        sections[current].append(text)
    source_lines, expected_lines = sections
    if source_lines and not expected_lines:
        # Without an output section the fixture is treated as pre-formatted.
        expected_lines = source_lines[:]
    source = "".join(source_lines).strip() + "\n"
    expected = "".join(expected_lines).strip() + "\n"
    return source, expected
54
55
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Patch black.CACHE_DIR to a temporary location and yield that path.

    With ``exists=False`` the yielded path is a "new" child of the temporary
    directory, which is not created here. The temporary directory is removed
    when the context exits.
    """
    with TemporaryDirectory() as workspace:
        target = Path(workspace) if exists else Path(workspace) / "new"
        with patch("black.CACHE_DIR", target):
            yield target
64
65
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Make a brand-new event loop current for the duration of the block.

    The loop that was current on entry is restored on exit. The temporary
    loop is closed afterwards only when ``close`` is true.
    """
    policy = asyncio.get_event_loop_policy()
    previous = policy.get_event_loop()
    fresh = policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield

    finally:
        policy.set_event_loop(previous)
        if close:
            fresh.close()
79
80
81 class BlackTestCase(unittest.TestCase):
82     maxDiff = None
83
84     def assertFormatEqual(self, expected: str, actual: str) -> None:
85         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
86             bdv: black.DebugVisitor[Any]
87             black.out("Expected tree:", fg="green")
88             try:
89                 exp_node = black.lib2to3_parse(expected)
90                 bdv = black.DebugVisitor()
91                 list(bdv.visit(exp_node))
92             except Exception as ve:
93                 black.err(str(ve))
94             black.out("Actual tree:", fg="red")
95             try:
96                 exp_node = black.lib2to3_parse(actual)
97                 bdv = black.DebugVisitor()
98                 list(bdv.visit(exp_node))
99             except Exception as ve:
100                 black.err(str(ve))
101         self.assertEqual(expected, actual)
102
103     @patch("black.dump_to_file", dump_to_stderr)
104     def test_self(self) -> None:
105         source, expected = read_data("test_black")
106         actual = fs(source)
107         self.assertFormatEqual(expected, actual)
108         black.assert_equivalent(source, actual)
109         black.assert_stable(source, actual, line_length=ll)
110         self.assertFalse(ff(THIS_FILE))
111
112     @patch("black.dump_to_file", dump_to_stderr)
113     def test_black(self) -> None:
114         source, expected = read_data("../black")
115         actual = fs(source)
116         self.assertFormatEqual(expected, actual)
117         black.assert_equivalent(source, actual)
118         black.assert_stable(source, actual, line_length=ll)
119         self.assertFalse(ff(THIS_DIR / ".." / "black.py"))
120
121     def test_piping(self) -> None:
122         source, expected = read_data("../black")
123         hold_stdin, hold_stdout = sys.stdin, sys.stdout
124         try:
125             sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
126             sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
127             sys.stdin.buffer.name = "<stdin>"  # type: ignore
128             black.format_stdin_to_stdout(
129                 line_length=ll, fast=True, write_back=black.WriteBack.YES
130             )
131             sys.stdout.seek(0)
132             actual = sys.stdout.read()
133         finally:
134             sys.stdin, sys.stdout = hold_stdin, hold_stdout
135         self.assertFormatEqual(expected, actual)
136         black.assert_equivalent(source, actual)
137         black.assert_stable(source, actual, line_length=ll)
138
139     def test_piping_diff(self) -> None:
140         diff_header = re.compile(
141             rf"(STDIN|STDOUT)\t\d\d\d\d-\d\d-\d\d "
142             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
143         )
144         source, _ = read_data("expression.py")
145         expected, _ = read_data("expression.diff")
146         hold_stdin, hold_stdout = sys.stdin, sys.stdout
147         try:
148             sys.stdin = TextIOWrapper(BytesIO(source.encode("utf8")), encoding="utf8")
149             sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
150             black.format_stdin_to_stdout(
151                 line_length=ll, fast=True, write_back=black.WriteBack.DIFF
152             )
153             sys.stdout.seek(0)
154             actual = sys.stdout.read()
155             actual = diff_header.sub("[Deterministic header]", actual)
156         finally:
157             sys.stdin, sys.stdout = hold_stdin, hold_stdout
158         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
159         self.assertEqual(expected, actual)
160
161     @patch("black.dump_to_file", dump_to_stderr)
162     def test_setup(self) -> None:
163         source, expected = read_data("../setup")
164         actual = fs(source)
165         self.assertFormatEqual(expected, actual)
166         black.assert_equivalent(source, actual)
167         black.assert_stable(source, actual, line_length=ll)
168         self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))
169
170     @patch("black.dump_to_file", dump_to_stderr)
171     def test_function(self) -> None:
172         source, expected = read_data("function")
173         actual = fs(source)
174         self.assertFormatEqual(expected, actual)
175         black.assert_equivalent(source, actual)
176         black.assert_stable(source, actual, line_length=ll)
177
178     @patch("black.dump_to_file", dump_to_stderr)
179     def test_function2(self) -> None:
180         source, expected = read_data("function2")
181         actual = fs(source)
182         self.assertFormatEqual(expected, actual)
183         black.assert_equivalent(source, actual)
184         black.assert_stable(source, actual, line_length=ll)
185
186     @patch("black.dump_to_file", dump_to_stderr)
187     def test_expression(self) -> None:
188         source, expected = read_data("expression")
189         actual = fs(source)
190         self.assertFormatEqual(expected, actual)
191         black.assert_equivalent(source, actual)
192         black.assert_stable(source, actual, line_length=ll)
193
194     def test_expression_ff(self) -> None:
195         source, expected = read_data("expression")
196         tmp_file = Path(black.dump_to_file(source))
197         try:
198             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
199             with open(tmp_file, encoding="utf8") as f:
200                 actual = f.read()
201         finally:
202             os.unlink(tmp_file)
203         self.assertFormatEqual(expected, actual)
204         with patch("black.dump_to_file", dump_to_stderr):
205             black.assert_equivalent(source, actual)
206             black.assert_stable(source, actual, line_length=ll)
207
208     def test_expression_diff(self) -> None:
209         source, _ = read_data("expression.py")
210         expected, _ = read_data("expression.diff")
211         tmp_file = Path(black.dump_to_file(source))
212         diff_header = re.compile(
213             rf"{re.escape(str(tmp_file))}\t\d\d\d\d-\d\d-\d\d "
214             rf"\d\d:\d\d:\d\d\.\d\d\d\d\d\d \+\d\d\d\d"
215         )
216         hold_stdout = sys.stdout
217         try:
218             sys.stdout = TextIOWrapper(BytesIO(), encoding="utf8")
219             self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
220             sys.stdout.seek(0)
221             actual = sys.stdout.read()
222             actual = diff_header.sub("[Deterministic header]", actual)
223         finally:
224             sys.stdout = hold_stdout
225             os.unlink(tmp_file)
226         actual = actual.rstrip() + "\n"  # the diff output has a trailing space
227         if expected != actual:
228             dump = black.dump_to_file(actual)
229             msg = (
230                 f"Expected diff isn't equal to the actual. If you made changes "
231                 f"to expression.py and this is an anticipated difference, "
232                 f"overwrite tests/expression.diff with {dump}"
233             )
234             self.assertEqual(expected, actual, msg)
235
236     @patch("black.dump_to_file", dump_to_stderr)
237     def test_fstring(self) -> None:
238         source, expected = read_data("fstring")
239         actual = fs(source)
240         self.assertFormatEqual(expected, actual)
241         black.assert_equivalent(source, actual)
242         black.assert_stable(source, actual, line_length=ll)
243
244     @patch("black.dump_to_file", dump_to_stderr)
245     def test_string_quotes(self) -> None:
246         source, expected = read_data("string_quotes")
247         actual = fs(source)
248         self.assertFormatEqual(expected, actual)
249         black.assert_equivalent(source, actual)
250         black.assert_stable(source, actual, line_length=ll)
251         mode = black.FileMode.NO_STRING_NORMALIZATION
252         not_normalized = fs(source, mode=mode)
253         self.assertFormatEqual(source, not_normalized)
254         black.assert_equivalent(source, not_normalized)
255         black.assert_stable(source, not_normalized, line_length=ll, mode=mode)
256
257     @patch("black.dump_to_file", dump_to_stderr)
258     def test_slices(self) -> None:
259         source, expected = read_data("slices")
260         actual = fs(source)
261         self.assertFormatEqual(expected, actual)
262         black.assert_equivalent(source, actual)
263         black.assert_stable(source, actual, line_length=ll)
264
265     @patch("black.dump_to_file", dump_to_stderr)
266     def test_comments(self) -> None:
267         source, expected = read_data("comments")
268         actual = fs(source)
269         self.assertFormatEqual(expected, actual)
270         black.assert_equivalent(source, actual)
271         black.assert_stable(source, actual, line_length=ll)
272
273     @patch("black.dump_to_file", dump_to_stderr)
274     def test_comments2(self) -> None:
275         source, expected = read_data("comments2")
276         actual = fs(source)
277         self.assertFormatEqual(expected, actual)
278         black.assert_equivalent(source, actual)
279         black.assert_stable(source, actual, line_length=ll)
280
281     @patch("black.dump_to_file", dump_to_stderr)
282     def test_comments3(self) -> None:
283         source, expected = read_data("comments3")
284         actual = fs(source)
285         self.assertFormatEqual(expected, actual)
286         black.assert_equivalent(source, actual)
287         black.assert_stable(source, actual, line_length=ll)
288
289     @patch("black.dump_to_file", dump_to_stderr)
290     def test_comments4(self) -> None:
291         source, expected = read_data("comments4")
292         actual = fs(source)
293         self.assertFormatEqual(expected, actual)
294         black.assert_equivalent(source, actual)
295         black.assert_stable(source, actual, line_length=ll)
296
297     @patch("black.dump_to_file", dump_to_stderr)
298     def test_comments5(self) -> None:
299         source, expected = read_data("comments5")
300         actual = fs(source)
301         self.assertFormatEqual(expected, actual)
302         black.assert_equivalent(source, actual)
303         black.assert_stable(source, actual, line_length=ll)
304
305     @patch("black.dump_to_file", dump_to_stderr)
306     def test_cantfit(self) -> None:
307         source, expected = read_data("cantfit")
308         actual = fs(source)
309         self.assertFormatEqual(expected, actual)
310         black.assert_equivalent(source, actual)
311         black.assert_stable(source, actual, line_length=ll)
312
313     @patch("black.dump_to_file", dump_to_stderr)
314     def test_import_spacing(self) -> None:
315         source, expected = read_data("import_spacing")
316         actual = fs(source)
317         self.assertFormatEqual(expected, actual)
318         black.assert_equivalent(source, actual)
319         black.assert_stable(source, actual, line_length=ll)
320
321     @patch("black.dump_to_file", dump_to_stderr)
322     def test_composition(self) -> None:
323         source, expected = read_data("composition")
324         actual = fs(source)
325         self.assertFormatEqual(expected, actual)
326         black.assert_equivalent(source, actual)
327         black.assert_stable(source, actual, line_length=ll)
328
329     @patch("black.dump_to_file", dump_to_stderr)
330     def test_empty_lines(self) -> None:
331         source, expected = read_data("empty_lines")
332         actual = fs(source)
333         self.assertFormatEqual(expected, actual)
334         black.assert_equivalent(source, actual)
335         black.assert_stable(source, actual, line_length=ll)
336
337     @patch("black.dump_to_file", dump_to_stderr)
338     def test_string_prefixes(self) -> None:
339         source, expected = read_data("string_prefixes")
340         actual = fs(source)
341         self.assertFormatEqual(expected, actual)
342         black.assert_equivalent(source, actual)
343         black.assert_stable(source, actual, line_length=ll)
344
345     @patch("black.dump_to_file", dump_to_stderr)
346     def test_python2(self) -> None:
347         source, expected = read_data("python2")
348         actual = fs(source)
349         self.assertFormatEqual(expected, actual)
350         # black.assert_equivalent(source, actual)
351         black.assert_stable(source, actual, line_length=ll)
352
353     @patch("black.dump_to_file", dump_to_stderr)
354     def test_python2_unicode_literals(self) -> None:
355         source, expected = read_data("python2_unicode_literals")
356         actual = fs(source)
357         self.assertFormatEqual(expected, actual)
358         black.assert_stable(source, actual, line_length=ll)
359
360     @patch("black.dump_to_file", dump_to_stderr)
361     def test_stub(self) -> None:
362         mode = black.FileMode.PYI
363         source, expected = read_data("stub.pyi")
364         actual = fs(source, mode=mode)
365         self.assertFormatEqual(expected, actual)
366         black.assert_stable(source, actual, line_length=ll, mode=mode)
367
368     @patch("black.dump_to_file", dump_to_stderr)
369     def test_fmtonoff(self) -> None:
370         source, expected = read_data("fmtonoff")
371         actual = fs(source)
372         self.assertFormatEqual(expected, actual)
373         black.assert_equivalent(source, actual)
374         black.assert_stable(source, actual, line_length=ll)
375
376     @patch("black.dump_to_file", dump_to_stderr)
377     def test_remove_empty_parentheses_after_class(self) -> None:
378         source, expected = read_data("class_blank_parentheses")
379         actual = fs(source)
380         self.assertFormatEqual(expected, actual)
381         black.assert_equivalent(source, actual)
382         black.assert_stable(source, actual, line_length=ll)
383
384     @patch("black.dump_to_file", dump_to_stderr)
385     def test_new_line_between_class_and_code(self) -> None:
386         source, expected = read_data("class_methods_new_line")
387         actual = fs(source)
388         self.assertFormatEqual(expected, actual)
389         black.assert_equivalent(source, actual)
390         black.assert_stable(source, actual, line_length=ll)
391
392     def test_report_verbose(self) -> None:
393         report = black.Report(verbose=True)
394         out_lines = []
395         err_lines = []
396
397         def out(msg: str, **kwargs: Any) -> None:
398             out_lines.append(msg)
399
400         def err(msg: str, **kwargs: Any) -> None:
401             err_lines.append(msg)
402
403         with patch("black.out", out), patch("black.err", err):
404             report.done(Path("f1"), black.Changed.NO)
405             self.assertEqual(len(out_lines), 1)
406             self.assertEqual(len(err_lines), 0)
407             self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
408             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
409             self.assertEqual(report.return_code, 0)
410             report.done(Path("f2"), black.Changed.YES)
411             self.assertEqual(len(out_lines), 2)
412             self.assertEqual(len(err_lines), 0)
413             self.assertEqual(out_lines[-1], "reformatted f2")
414             self.assertEqual(
415                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
416             )
417             report.done(Path("f3"), black.Changed.CACHED)
418             self.assertEqual(len(out_lines), 3)
419             self.assertEqual(len(err_lines), 0)
420             self.assertEqual(
421                 out_lines[-1], "f3 wasn't modified on disk since last run."
422             )
423             self.assertEqual(
424                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
425             )
426             self.assertEqual(report.return_code, 0)
427             report.check = True
428             self.assertEqual(report.return_code, 1)
429             report.check = False
430             report.failed(Path("e1"), "boom")
431             self.assertEqual(len(out_lines), 3)
432             self.assertEqual(len(err_lines), 1)
433             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
434             self.assertEqual(
435                 unstyle(str(report)),
436                 "1 file reformatted, 2 files left unchanged, "
437                 "1 file failed to reformat.",
438             )
439             self.assertEqual(report.return_code, 123)
440             report.done(Path("f3"), black.Changed.YES)
441             self.assertEqual(len(out_lines), 4)
442             self.assertEqual(len(err_lines), 1)
443             self.assertEqual(out_lines[-1], "reformatted f3")
444             self.assertEqual(
445                 unstyle(str(report)),
446                 "2 files reformatted, 2 files left unchanged, "
447                 "1 file failed to reformat.",
448             )
449             self.assertEqual(report.return_code, 123)
450             report.failed(Path("e2"), "boom")
451             self.assertEqual(len(out_lines), 4)
452             self.assertEqual(len(err_lines), 2)
453             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
454             self.assertEqual(
455                 unstyle(str(report)),
456                 "2 files reformatted, 2 files left unchanged, "
457                 "2 files failed to reformat.",
458             )
459             self.assertEqual(report.return_code, 123)
460             report.path_ignored(Path("wat"), "no match")
461             self.assertEqual(len(out_lines), 5)
462             self.assertEqual(len(err_lines), 2)
463             self.assertEqual(out_lines[-1], "wat ignored: no match")
464             self.assertEqual(
465                 unstyle(str(report)),
466                 "2 files reformatted, 2 files left unchanged, "
467                 "2 files failed to reformat.",
468             )
469             self.assertEqual(report.return_code, 123)
470             report.done(Path("f4"), black.Changed.NO)
471             self.assertEqual(len(out_lines), 6)
472             self.assertEqual(len(err_lines), 2)
473             self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
474             self.assertEqual(
475                 unstyle(str(report)),
476                 "2 files reformatted, 3 files left unchanged, "
477                 "2 files failed to reformat.",
478             )
479             self.assertEqual(report.return_code, 123)
480             report.check = True
481             self.assertEqual(
482                 unstyle(str(report)),
483                 "2 files would be reformatted, 3 files would be left unchanged, "
484                 "2 files would fail to reformat.",
485             )
486
487     def test_report_quiet(self) -> None:
488         report = black.Report(quiet=True)
489         out_lines = []
490         err_lines = []
491
492         def out(msg: str, **kwargs: Any) -> None:
493             out_lines.append(msg)
494
495         def err(msg: str, **kwargs: Any) -> None:
496             err_lines.append(msg)
497
498         with patch("black.out", out), patch("black.err", err):
499             report.done(Path("f1"), black.Changed.NO)
500             self.assertEqual(len(out_lines), 0)
501             self.assertEqual(len(err_lines), 0)
502             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
503             self.assertEqual(report.return_code, 0)
504             report.done(Path("f2"), black.Changed.YES)
505             self.assertEqual(len(out_lines), 0)
506             self.assertEqual(len(err_lines), 0)
507             self.assertEqual(
508                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
509             )
510             report.done(Path("f3"), black.Changed.CACHED)
511             self.assertEqual(len(out_lines), 0)
512             self.assertEqual(len(err_lines), 0)
513             self.assertEqual(
514                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
515             )
516             self.assertEqual(report.return_code, 0)
517             report.check = True
518             self.assertEqual(report.return_code, 1)
519             report.check = False
520             report.failed(Path("e1"), "boom")
521             self.assertEqual(len(out_lines), 0)
522             self.assertEqual(len(err_lines), 1)
523             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
524             self.assertEqual(
525                 unstyle(str(report)),
526                 "1 file reformatted, 2 files left unchanged, "
527                 "1 file failed to reformat.",
528             )
529             self.assertEqual(report.return_code, 123)
530             report.done(Path("f3"), black.Changed.YES)
531             self.assertEqual(len(out_lines), 0)
532             self.assertEqual(len(err_lines), 1)
533             self.assertEqual(
534                 unstyle(str(report)),
535                 "2 files reformatted, 2 files left unchanged, "
536                 "1 file failed to reformat.",
537             )
538             self.assertEqual(report.return_code, 123)
539             report.failed(Path("e2"), "boom")
540             self.assertEqual(len(out_lines), 0)
541             self.assertEqual(len(err_lines), 2)
542             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
543             self.assertEqual(
544                 unstyle(str(report)),
545                 "2 files reformatted, 2 files left unchanged, "
546                 "2 files failed to reformat.",
547             )
548             self.assertEqual(report.return_code, 123)
549             report.path_ignored(Path("wat"), "no match")
550             self.assertEqual(len(out_lines), 0)
551             self.assertEqual(len(err_lines), 2)
552             self.assertEqual(
553                 unstyle(str(report)),
554                 "2 files reformatted, 2 files left unchanged, "
555                 "2 files failed to reformat.",
556             )
557             self.assertEqual(report.return_code, 123)
558             report.done(Path("f4"), black.Changed.NO)
559             self.assertEqual(len(out_lines), 0)
560             self.assertEqual(len(err_lines), 2)
561             self.assertEqual(
562                 unstyle(str(report)),
563                 "2 files reformatted, 3 files left unchanged, "
564                 "2 files failed to reformat.",
565             )
566             self.assertEqual(report.return_code, 123)
567             report.check = True
568             self.assertEqual(
569                 unstyle(str(report)),
570                 "2 files would be reformatted, 3 files would be left unchanged, "
571                 "2 files would fail to reformat.",
572             )
573
574     def test_report_normal(self) -> None:
575         report = black.Report()
576         out_lines = []
577         err_lines = []
578
579         def out(msg: str, **kwargs: Any) -> None:
580             out_lines.append(msg)
581
582         def err(msg: str, **kwargs: Any) -> None:
583             err_lines.append(msg)
584
585         with patch("black.out", out), patch("black.err", err):
586             report.done(Path("f1"), black.Changed.NO)
587             self.assertEqual(len(out_lines), 0)
588             self.assertEqual(len(err_lines), 0)
589             self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
590             self.assertEqual(report.return_code, 0)
591             report.done(Path("f2"), black.Changed.YES)
592             self.assertEqual(len(out_lines), 1)
593             self.assertEqual(len(err_lines), 0)
594             self.assertEqual(out_lines[-1], "reformatted f2")
595             self.assertEqual(
596                 unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
597             )
598             report.done(Path("f3"), black.Changed.CACHED)
599             self.assertEqual(len(out_lines), 1)
600             self.assertEqual(len(err_lines), 0)
601             self.assertEqual(out_lines[-1], "reformatted f2")
602             self.assertEqual(
603                 unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
604             )
605             self.assertEqual(report.return_code, 0)
606             report.check = True
607             self.assertEqual(report.return_code, 1)
608             report.check = False
609             report.failed(Path("e1"), "boom")
610             self.assertEqual(len(out_lines), 1)
611             self.assertEqual(len(err_lines), 1)
612             self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
613             self.assertEqual(
614                 unstyle(str(report)),
615                 "1 file reformatted, 2 files left unchanged, "
616                 "1 file failed to reformat.",
617             )
618             self.assertEqual(report.return_code, 123)
619             report.done(Path("f3"), black.Changed.YES)
620             self.assertEqual(len(out_lines), 2)
621             self.assertEqual(len(err_lines), 1)
622             self.assertEqual(out_lines[-1], "reformatted f3")
623             self.assertEqual(
624                 unstyle(str(report)),
625                 "2 files reformatted, 2 files left unchanged, "
626                 "1 file failed to reformat.",
627             )
628             self.assertEqual(report.return_code, 123)
629             report.failed(Path("e2"), "boom")
630             self.assertEqual(len(out_lines), 2)
631             self.assertEqual(len(err_lines), 2)
632             self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
633             self.assertEqual(
634                 unstyle(str(report)),
635                 "2 files reformatted, 2 files left unchanged, "
636                 "2 files failed to reformat.",
637             )
638             self.assertEqual(report.return_code, 123)
639             report.path_ignored(Path("wat"), "no match")
640             self.assertEqual(len(out_lines), 2)
641             self.assertEqual(len(err_lines), 2)
642             self.assertEqual(
643                 unstyle(str(report)),
644                 "2 files reformatted, 2 files left unchanged, "
645                 "2 files failed to reformat.",
646             )
647             self.assertEqual(report.return_code, 123)
648             report.done(Path("f4"), black.Changed.NO)
649             self.assertEqual(len(out_lines), 2)
650             self.assertEqual(len(err_lines), 2)
651             self.assertEqual(
652                 unstyle(str(report)),
653                 "2 files reformatted, 3 files left unchanged, "
654                 "2 files failed to reformat.",
655             )
656             self.assertEqual(report.return_code, 123)
657             report.check = True
658             self.assertEqual(
659                 unstyle(str(report)),
660                 "2 files would be reformatted, 3 files would be left unchanged, "
661                 "2 files would fail to reformat.",
662             )
663
664     def test_is_python36(self) -> None:
665         node = black.lib2to3_parse("def f(*, arg): ...\n")
666         self.assertFalse(black.is_python36(node))
667         node = black.lib2to3_parse("def f(*, arg,): ...\n")
668         self.assertTrue(black.is_python36(node))
669         node = black.lib2to3_parse("def f(*, arg): f'string'\n")
670         self.assertTrue(black.is_python36(node))
671         source, expected = read_data("function")
672         node = black.lib2to3_parse(source)
673         self.assertTrue(black.is_python36(node))
674         node = black.lib2to3_parse(expected)
675         self.assertTrue(black.is_python36(node))
676         source, expected = read_data("expression")
677         node = black.lib2to3_parse(source)
678         self.assertFalse(black.is_python36(node))
679         node = black.lib2to3_parse(expected)
680         self.assertFalse(black.is_python36(node))
681
682     def test_get_future_imports(self) -> None:
683         node = black.lib2to3_parse("\n")
684         self.assertEqual(set(), black.get_future_imports(node))
685         node = black.lib2to3_parse("from __future__ import black\n")
686         self.assertEqual({"black"}, black.get_future_imports(node))
687         node = black.lib2to3_parse("from __future__ import multiple, imports\n")
688         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
689         node = black.lib2to3_parse("from __future__ import (parenthesized, imports)\n")
690         self.assertEqual({"parenthesized", "imports"}, black.get_future_imports(node))
691         node = black.lib2to3_parse(
692             "from __future__ import multiple\nfrom __future__ import imports\n"
693         )
694         self.assertEqual({"multiple", "imports"}, black.get_future_imports(node))
695         node = black.lib2to3_parse("# comment\nfrom __future__ import black\n")
696         self.assertEqual({"black"}, black.get_future_imports(node))
697         node = black.lib2to3_parse('"""docstring"""\nfrom __future__ import black\n')
698         self.assertEqual({"black"}, black.get_future_imports(node))
699         node = black.lib2to3_parse("some(other, code)\nfrom __future__ import black\n")
700         self.assertEqual(set(), black.get_future_imports(node))
701         node = black.lib2to3_parse("from some.module import black\n")
702         self.assertEqual(set(), black.get_future_imports(node))
703
704     def test_debug_visitor(self) -> None:
705         source, _ = read_data("debug_visitor.py")
706         expected, _ = read_data("debug_visitor.out")
707         out_lines = []
708         err_lines = []
709
710         def out(msg: str, **kwargs: Any) -> None:
711             out_lines.append(msg)
712
713         def err(msg: str, **kwargs: Any) -> None:
714             err_lines.append(msg)
715
716         with patch("black.out", out), patch("black.err", err):
717             black.DebugVisitor.show(source)
718         actual = "\n".join(out_lines) + "\n"
719         log_name = ""
720         if expected != actual:
721             log_name = black.dump_to_file(*out_lines)
722         self.assertEqual(
723             expected,
724             actual,
725             f"AST print out is different. Actual version dumped to {log_name}",
726         )
727
728     def test_format_file_contents(self) -> None:
729         empty = ""
730         with self.assertRaises(black.NothingChanged):
731             black.format_file_contents(empty, line_length=ll, fast=False)
732         just_nl = "\n"
733         with self.assertRaises(black.NothingChanged):
734             black.format_file_contents(just_nl, line_length=ll, fast=False)
735         same = "l = [1, 2, 3]\n"
736         with self.assertRaises(black.NothingChanged):
737             black.format_file_contents(same, line_length=ll, fast=False)
738         different = "l = [1,2,3]"
739         expected = same
740         actual = black.format_file_contents(different, line_length=ll, fast=False)
741         self.assertEqual(expected, actual)
742         invalid = "return if you can"
743         with self.assertRaises(ValueError) as e:
744             black.format_file_contents(invalid, line_length=ll, fast=False)
745         self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")
746
747     def test_endmarker(self) -> None:
748         n = black.lib2to3_parse("\n")
749         self.assertEqual(n.type, black.syms.file_input)
750         self.assertEqual(len(n.children), 1)
751         self.assertEqual(n.children[0].type, black.token.ENDMARKER)
752
753     @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
754     def test_assertFormatEqual(self) -> None:
755         out_lines = []
756         err_lines = []
757
758         def out(msg: str, **kwargs: Any) -> None:
759             out_lines.append(msg)
760
761         def err(msg: str, **kwargs: Any) -> None:
762             err_lines.append(msg)
763
764         with patch("black.out", out), patch("black.err", err):
765             with self.assertRaises(AssertionError):
766                 self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")
767
768         out_str = "".join(out_lines)
769         self.assertTrue("Expected tree:" in out_str)
770         self.assertTrue("Actual tree:" in out_str)
771         self.assertEqual("".join(err_lines), "")
772
773     def test_cache_broken_file(self) -> None:
774         mode = black.FileMode.AUTO_DETECT
775         with cache_dir() as workspace:
776             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
777             with cache_file.open("w") as fobj:
778                 fobj.write("this is not a pickle")
779             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
780             src = (workspace / "test.py").resolve()
781             with src.open("w") as fobj:
782                 fobj.write("print('hello')")
783             result = CliRunner().invoke(black.main, [str(src)])
784             self.assertEqual(result.exit_code, 0)
785             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
786             self.assertIn(src, cache)
787
788     def test_cache_single_file_already_cached(self) -> None:
789         mode = black.FileMode.AUTO_DETECT
790         with cache_dir() as workspace:
791             src = (workspace / "test.py").resolve()
792             with src.open("w") as fobj:
793                 fobj.write("print('hello')")
794             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
795             result = CliRunner().invoke(black.main, [str(src)])
796             self.assertEqual(result.exit_code, 0)
797             with src.open("r") as fobj:
798                 self.assertEqual(fobj.read(), "print('hello')")
799
800     @event_loop(close=False)
801     def test_cache_multiple_files(self) -> None:
802         mode = black.FileMode.AUTO_DETECT
803         with cache_dir() as workspace, patch(
804             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
805         ):
806             one = (workspace / "one.py").resolve()
807             with one.open("w") as fobj:
808                 fobj.write("print('hello')")
809             two = (workspace / "two.py").resolve()
810             with two.open("w") as fobj:
811                 fobj.write("print('hello')")
812             black.write_cache({}, [one], black.DEFAULT_LINE_LENGTH, mode)
813             result = CliRunner().invoke(black.main, [str(workspace)])
814             self.assertEqual(result.exit_code, 0)
815             with one.open("r") as fobj:
816                 self.assertEqual(fobj.read(), "print('hello')")
817             with two.open("r") as fobj:
818                 self.assertEqual(fobj.read(), 'print("hello")\n')
819             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
820             self.assertIn(one, cache)
821             self.assertIn(two, cache)
822
823     def test_no_cache_when_writeback_diff(self) -> None:
824         mode = black.FileMode.AUTO_DETECT
825         with cache_dir() as workspace:
826             src = (workspace / "test.py").resolve()
827             with src.open("w") as fobj:
828                 fobj.write("print('hello')")
829             result = CliRunner().invoke(black.main, [str(src), "--diff"])
830             self.assertEqual(result.exit_code, 0)
831             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
832             self.assertFalse(cache_file.exists())
833
834     def test_no_cache_when_stdin(self) -> None:
835         mode = black.FileMode.AUTO_DETECT
836         with cache_dir():
837             result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
838             self.assertEqual(result.exit_code, 0)
839             cache_file = black.get_cache_file(black.DEFAULT_LINE_LENGTH, mode)
840             self.assertFalse(cache_file.exists())
841
842     def test_read_cache_no_cachefile(self) -> None:
843         mode = black.FileMode.AUTO_DETECT
844         with cache_dir():
845             self.assertEqual(black.read_cache(black.DEFAULT_LINE_LENGTH, mode), {})
846
847     def test_write_cache_read_cache(self) -> None:
848         mode = black.FileMode.AUTO_DETECT
849         with cache_dir() as workspace:
850             src = (workspace / "test.py").resolve()
851             src.touch()
852             black.write_cache({}, [src], black.DEFAULT_LINE_LENGTH, mode)
853             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
854             self.assertIn(src, cache)
855             self.assertEqual(cache[src], black.get_cache_info(src))
856
857     def test_filter_cached(self) -> None:
858         with TemporaryDirectory() as workspace:
859             path = Path(workspace)
860             uncached = (path / "uncached").resolve()
861             cached = (path / "cached").resolve()
862             cached_but_changed = (path / "changed").resolve()
863             uncached.touch()
864             cached.touch()
865             cached_but_changed.touch()
866             cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
867             todo, done = black.filter_cached(
868                 cache, {uncached, cached, cached_but_changed}
869             )
870             self.assertEqual(todo, {uncached, cached_but_changed})
871             self.assertEqual(done, {cached})
872
873     def test_write_cache_creates_directory_if_needed(self) -> None:
874         mode = black.FileMode.AUTO_DETECT
875         with cache_dir(exists=False) as workspace:
876             self.assertFalse(workspace.exists())
877             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
878             self.assertTrue(workspace.exists())
879
880     @event_loop(close=False)
881     def test_failed_formatting_does_not_get_cached(self) -> None:
882         mode = black.FileMode.AUTO_DETECT
883         with cache_dir() as workspace, patch(
884             "black.ProcessPoolExecutor", new=ThreadPoolExecutor
885         ):
886             failing = (workspace / "failing.py").resolve()
887             with failing.open("w") as fobj:
888                 fobj.write("not actually python")
889             clean = (workspace / "clean.py").resolve()
890             with clean.open("w") as fobj:
891                 fobj.write('print("hello")\n')
892             result = CliRunner().invoke(black.main, [str(workspace)])
893             self.assertEqual(result.exit_code, 123)
894             cache = black.read_cache(black.DEFAULT_LINE_LENGTH, mode)
895             self.assertNotIn(failing, cache)
896             self.assertIn(clean, cache)
897
898     def test_write_cache_write_fail(self) -> None:
899         mode = black.FileMode.AUTO_DETECT
900         with cache_dir(), patch.object(Path, "open") as mock:
901             mock.side_effect = OSError
902             black.write_cache({}, [], black.DEFAULT_LINE_LENGTH, mode)
903
904     @event_loop(close=False)
905     def test_check_diff_use_together(self) -> None:
906         with cache_dir():
907             # Files which will be reformatted.
908             src1 = (THIS_DIR / "string_quotes.py").resolve()
909             result = CliRunner().invoke(black.main, [str(src1), "--diff", "--check"])
910             self.assertEqual(result.exit_code, 1)
911
912             # Files which will not be reformatted.
913             src2 = (THIS_DIR / "composition.py").resolve()
914             result = CliRunner().invoke(black.main, [str(src2), "--diff", "--check"])
915             self.assertEqual(result.exit_code, 0)
916
917             # Multi file command.
918             result = CliRunner().invoke(
919                 black.main, [str(src1), str(src2), "--diff", "--check"]
920             )
921             self.assertEqual(result.exit_code, 1, result.output)
922
923     def test_no_files(self) -> None:
924         with cache_dir():
925             # Without an argument, black exits with error code 0.
926             result = CliRunner().invoke(black.main, [])
927             self.assertEqual(result.exit_code, 0)
928
929     def test_broken_symlink(self) -> None:
930         with cache_dir() as workspace:
931             symlink = workspace / "broken_link.py"
932             try:
933                 symlink.symlink_to("nonexistent.py")
934             except OSError as e:
935                 self.skipTest(f"Can't create symlinks: {e}")
936             result = CliRunner().invoke(black.main, [str(workspace.resolve())])
937             self.assertEqual(result.exit_code, 0)
938
939     def test_read_cache_line_lengths(self) -> None:
940         mode = black.FileMode.AUTO_DETECT
941         with cache_dir() as workspace:
942             path = (workspace / "file.py").resolve()
943             path.touch()
944             black.write_cache({}, [path], 1, mode)
945             one = black.read_cache(1, mode)
946             self.assertIn(path, one)
947             two = black.read_cache(2, mode)
948             self.assertNotIn(path, two)
949
950     def test_single_file_force_pyi(self) -> None:
951         reg_mode = black.FileMode.AUTO_DETECT
952         pyi_mode = black.FileMode.PYI
953         contents, expected = read_data("force_pyi")
954         with cache_dir() as workspace:
955             path = (workspace / "file.py").resolve()
956             with open(path, "w") as fh:
957                 fh.write(contents)
958             result = CliRunner().invoke(black.main, [str(path), "--pyi"])
959             self.assertEqual(result.exit_code, 0)
960             with open(path, "r") as fh:
961                 actual = fh.read()
962             # verify cache with --pyi is separate
963             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
964             self.assertIn(path, pyi_cache)
965             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
966             self.assertNotIn(path, normal_cache)
967         self.assertEqual(actual, expected)
968
969     @event_loop(close=False)
970     def test_multi_file_force_pyi(self) -> None:
971         reg_mode = black.FileMode.AUTO_DETECT
972         pyi_mode = black.FileMode.PYI
973         contents, expected = read_data("force_pyi")
974         with cache_dir() as workspace:
975             paths = [
976                 (workspace / "file1.py").resolve(),
977                 (workspace / "file2.py").resolve(),
978             ]
979             for path in paths:
980                 with open(path, "w") as fh:
981                     fh.write(contents)
982             result = CliRunner().invoke(black.main, [str(p) for p in paths] + ["--pyi"])
983             self.assertEqual(result.exit_code, 0)
984             for path in paths:
985                 with open(path, "r") as fh:
986                     actual = fh.read()
987                 self.assertEqual(actual, expected)
988             # verify cache with --pyi is separate
989             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, pyi_mode)
990             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
991             for path in paths:
992                 self.assertIn(path, pyi_cache)
993                 self.assertNotIn(path, normal_cache)
994
995     def test_pipe_force_pyi(self) -> None:
996         source, expected = read_data("force_pyi")
997         result = CliRunner().invoke(black.main, ["-", "-q", "--pyi"], input=source)
998         self.assertEqual(result.exit_code, 0)
999         actual = result.output
1000         self.assertFormatEqual(actual, expected)
1001
1002     def test_single_file_force_py36(self) -> None:
1003         reg_mode = black.FileMode.AUTO_DETECT
1004         py36_mode = black.FileMode.PYTHON36
1005         source, expected = read_data("force_py36")
1006         with cache_dir() as workspace:
1007             path = (workspace / "file.py").resolve()
1008             with open(path, "w") as fh:
1009                 fh.write(source)
1010             result = CliRunner().invoke(black.main, [str(path), "--py36"])
1011             self.assertEqual(result.exit_code, 0)
1012             with open(path, "r") as fh:
1013                 actual = fh.read()
1014             # verify cache with --py36 is separate
1015             py36_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1016             self.assertIn(path, py36_cache)
1017             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1018             self.assertNotIn(path, normal_cache)
1019         self.assertEqual(actual, expected)
1020
1021     @event_loop(close=False)
1022     def test_multi_file_force_py36(self) -> None:
1023         reg_mode = black.FileMode.AUTO_DETECT
1024         py36_mode = black.FileMode.PYTHON36
1025         source, expected = read_data("force_py36")
1026         with cache_dir() as workspace:
1027             paths = [
1028                 (workspace / "file1.py").resolve(),
1029                 (workspace / "file2.py").resolve(),
1030             ]
1031             for path in paths:
1032                 with open(path, "w") as fh:
1033                     fh.write(source)
1034             result = CliRunner().invoke(
1035                 black.main, [str(p) for p in paths] + ["--py36"]
1036             )
1037             self.assertEqual(result.exit_code, 0)
1038             for path in paths:
1039                 with open(path, "r") as fh:
1040                     actual = fh.read()
1041                 self.assertEqual(actual, expected)
1042             # verify cache with --py36 is separate
1043             pyi_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, py36_mode)
1044             normal_cache = black.read_cache(black.DEFAULT_LINE_LENGTH, reg_mode)
1045             for path in paths:
1046                 self.assertIn(path, pyi_cache)
1047                 self.assertNotIn(path, normal_cache)
1048
1049     def test_pipe_force_py36(self) -> None:
1050         source, expected = read_data("force_py36")
1051         result = CliRunner().invoke(black.main, ["-", "-q", "--py36"], input=source)
1052         self.assertEqual(result.exit_code, 0)
1053         actual = result.output
1054         self.assertFormatEqual(actual, expected)
1055
1056     def test_include_exclude(self) -> None:
1057         path = THIS_DIR / "include_exclude_tests"
1058         include = re.compile(r"\.pyi?$")
1059         exclude = re.compile(r"/exclude/|/\.definitely_exclude/")
1060         report = black.Report()
1061         sources: List[Path] = []
1062         expected = [
1063             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.py"),
1064             Path(THIS_DIR / "include_exclude_tests/b/dont_exclude/a.pyi"),
1065         ]
1066         this_abs = THIS_DIR.resolve()
1067         sources.extend(
1068             black.gen_python_files_in_dir(path, this_abs, include, exclude, report)
1069         )
1070         self.assertEqual(sorted(expected), sorted(sources))
1071
1072     def test_empty_include(self) -> None:
1073         path = THIS_DIR / "include_exclude_tests"
1074         report = black.Report()
1075         empty = re.compile(r"")
1076         sources: List[Path] = []
1077         expected = [
1078             Path(path / "b/exclude/a.pie"),
1079             Path(path / "b/exclude/a.py"),
1080             Path(path / "b/exclude/a.pyi"),
1081             Path(path / "b/dont_exclude/a.pie"),
1082             Path(path / "b/dont_exclude/a.py"),
1083             Path(path / "b/dont_exclude/a.pyi"),
1084             Path(path / "b/.definitely_exclude/a.pie"),
1085             Path(path / "b/.definitely_exclude/a.py"),
1086             Path(path / "b/.definitely_exclude/a.pyi"),
1087         ]
1088         this_abs = THIS_DIR.resolve()
1089         sources.extend(
1090             black.gen_python_files_in_dir(
1091                 path, this_abs, empty, re.compile(black.DEFAULT_EXCLUDES), report
1092             )
1093         )
1094         self.assertEqual(sorted(expected), sorted(sources))
1095
1096     def test_empty_exclude(self) -> None:
1097         path = THIS_DIR / "include_exclude_tests"
1098         report = black.Report()
1099         empty = re.compile(r"")
1100         sources: List[Path] = []
1101         expected = [
1102             Path(path / "b/dont_exclude/a.py"),
1103             Path(path / "b/dont_exclude/a.pyi"),
1104             Path(path / "b/exclude/a.py"),
1105             Path(path / "b/exclude/a.pyi"),
1106             Path(path / "b/.definitely_exclude/a.py"),
1107             Path(path / "b/.definitely_exclude/a.pyi"),
1108         ]
1109         this_abs = THIS_DIR.resolve()
1110         sources.extend(
1111             black.gen_python_files_in_dir(
1112                 path, this_abs, re.compile(black.DEFAULT_INCLUDES), empty, report
1113             )
1114         )
1115         self.assertEqual(sorted(expected), sorted(sources))
1116
1117     def test_invalid_include_exclude(self) -> None:
1118         for option in ["--include", "--exclude"]:
1119             result = CliRunner().invoke(black.main, ["-", option, "**()(!!*)"])
1120             self.assertEqual(result.exit_code, 2)
1121
1122     def test_preserves_line_endings(self) -> None:
1123         with TemporaryDirectory() as workspace:
1124             test_file = Path(workspace) / "test.py"
1125             for nl in ["\n", "\r\n"]:
1126                 contents = nl.join(["def f(  ):", "    pass"])
1127                 test_file.write_bytes(contents.encode())
1128                 ff(test_file, write_back=black.WriteBack.YES)
1129                 updated_contents: bytes = test_file.read_bytes()
1130                 self.assertIn(nl.encode(), updated_contents)  # type: ignore
1131                 if nl == "\n":
1132                     self.assertNotIn(b"\r\n", updated_contents)  # type: ignore
1133
1134
if __name__ == "__main__":
    # Allow running this test module directly, e.g. `python tests/test_black.py`.
    unittest.main(module="__main__")