git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Store pickles for 3.8.0a0
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 import asyncio
3 from contextlib import contextmanager
4 from functools import partial
5 from io import StringIO
6 import os
7 from pathlib import Path
8 import sys
9 from tempfile import TemporaryDirectory
10 from typing import Any, List, Tuple, Iterator
11 import unittest
12 from unittest.mock import patch
13
14 from click import unstyle
15 from click.testing import CliRunner
16
17 import black
18
ll = 88  # line length shared by every formatting helper below
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Assembled from two literals so that this very file -- which test_self feeds
# through read_data() -- never contains the complete marker on one line.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
25
26
def dump_to_stderr(*output: str) -> str:
    """Stand-in for ``black.dump_to_file`` that returns text instead of a path.

    The pieces of *output* are joined with newlines, and the result is wrapped
    in one leading and one trailing newline.  Presumably black embeds the
    returned value in its failure messages where a temp-file path would go.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
29
30
def read_data(name: str) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load a fixture file from the tests directory; a name without a ".py",
    ".out" or ".diff" suffix gets ".py" appended.  EMPTY_LINE is removed from
    every line first.  Lines before a ``# output`` marker line form the input,
    lines after it form the expected output (marker lines themselves are
    dropped).  If the output section ends up empty, the input is reused as
    the expected output.  Both results are stripped and end in one newline.
    """
    if not name.endswith((".py", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as test:
        lines = test.readlines()
    sections: Tuple[List[str], List[str]] = ([], [])
    target = sections[0]
    for raw_line in lines:
        line = raw_line.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            # Every later line belongs to the expected-output section.
            target = sections[1]
            continue

        target.append(line)
    source_lines, expected_lines = sections
    if source_lines and not expected_lines:
        # Nothing after the marker (or no marker): treat the input as
        # already formatted.
        expected_lines = list(source_lines)
    source = "".join(source_lines).strip() + "\n"
    expected = "".join(expected_lines).strip() + "\n"
    return source, expected
51
52
@contextmanager
def cache_dir(exists: bool = True) -> Iterator[Path]:
    """Point black's cache at a throwaway directory for the duration of a block.

    ``black.CACHE_DIR`` is patched to the yielded directory and
    ``black.CACHE_FILE`` to ``cache.pkl`` inside it.  With ``exists=False`` the
    yielded directory is a not-yet-created ``new`` subdirectory of the
    temporary workspace.  Everything is removed again on exit.
    """
    with TemporaryDirectory() as tmp:
        root = Path(tmp)
        workspace = root if exists else root / "new"
        cache_file = workspace / "cache.pkl"
        with patch("black.CACHE_DIR", workspace), patch("black.CACHE_FILE", cache_file):
            yield workspace
62
63
@contextmanager
def event_loop(close: bool) -> Iterator[None]:
    """Run the enclosed code with a fresh asyncio event loop as the current one.

    The loop that was current beforehand is reinstated on exit.  The fresh
    loop is closed on exit only when *close* is true; with ``close=False`` it
    is left open.  Usable as a decorator too, since it is built with
    ``contextmanager``.
    """
    loop_policy = asyncio.get_event_loop_policy()
    previous = loop_policy.get_event_loop()
    fresh = loop_policy.new_event_loop()
    asyncio.set_event_loop(fresh)
    try:
        yield

    finally:
        loop_policy.set_event_loop(previous)
        if close:
            fresh.close()
77
78
class BlackTestCase(unittest.TestCase):
    """Tests for black, mostly driven by fixture files read via read_data()."""

    maxDiff = None

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that ``expected == actual``, dumping both syntax trees on mismatch.

        On a mismatch, and unless SKIP_AST_PRINT is set in the environment,
        both strings are parsed and walked with black's DebugVisitor (whose
        output goes through ``black.out``).  A parse failure while dumping is
        reported through ``black.err`` and does not stop the final assertion.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            bdv: black.DebugVisitor[Any]
            black.out("Expected tree:", fg="green")
            try:
                exp_node = black.lib2to3_parse(expected)
                bdv = black.DebugVisitor()
                # visit() is lazy; exhaust it so the tree actually gets printed.
                list(bdv.visit(exp_node))
            except Exception as ve:
                black.err(str(ve))
            black.out("Actual tree:", fg="red")
            try:
                act_node = black.lib2to3_parse(actual)
                bdv = black.DebugVisitor()
                list(bdv.visit(act_node))
            except Exception as ve:
                black.err(str(ve))
        self.assertEqual(expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        # This very file must be black-clean: formatting it in place changes nothing.
        source, expected = read_data("test_black")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        source, expected = read_data("../black")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)
        self.assertFalse(ff(THIS_DIR / ".." / "black.py"))

    def test_piping(self) -> None:
        source, expected = read_data("../black")
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=black.WriteBack.YES
            )
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_piping_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=black.WriteBack.DIFF
            )
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        self.assertEqual(expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        source, expected = read_data("../setup")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)
        self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        source, expected = read_data("function")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        source, expected = read_data("expression")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_expression_ff(self) -> None:
        source, expected = read_data("expression")
        tmp_file = Path(black.dump_to_file(source))
        try:
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
            with open(tmp_file, encoding="utf8") as f:
                actual = f.read()
        finally:
            os.unlink(tmp_file)
        self.assertFormatEqual(expected, actual)
        with patch("black.dump_to_file", dump_to_stderr):
            black.assert_equivalent(source, actual)
            black.assert_stable(source, actual, line_length=ll)

    def test_expression_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        tmp_file = Path(black.dump_to_file(source))
        hold_stdout = sys.stdout
        try:
            sys.stdout = StringIO()
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
            sys.stdout.seek(0)
            actual = sys.stdout.read()
            # The diff names the temporary file; normalize it to the fixture's label.
            actual = actual.replace(tmp_file.name, "<stdin>")
        finally:
            sys.stdout = hold_stdout
            os.unlink(tmp_file)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        if expected != actual:
            dump = black.dump_to_file(actual)
            msg = (
                f"Expected diff isn't equal to the actual. If you made changes "
                f"to expression.py and this is an anticipated difference, "
                f"overwrite tests/expression.diff with {dump}"
            )
            self.assertEqual(expected, actual, msg)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        source, expected = read_data("fstring")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_string_quotes(self) -> None:
        source, expected = read_data("string_quotes")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        source, expected = read_data("comments")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        source, expected = read_data("comments2")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments3(self) -> None:
        source, expected = read_data("comments3")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments4(self) -> None:
        source, expected = read_data("comments4")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        source, expected = read_data("cantfit")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        source, expected = read_data("import_spacing")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        source, expected = read_data("composition")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        source, expected = read_data("empty_lines")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        source, expected = read_data("python2")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        # NOTE(review): the equivalence check is disabled here, presumably because
        # it cannot parse Python 2 source -- confirm before re-enabling.
        # black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fmtonoff(self) -> None:
        source, expected = read_data("fmtonoff")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_report(self) -> None:
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), black.Changed.NO)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), black.Changed.YES)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            report.done(Path("f3"), black.Changed.CACHED)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(
                out_lines[-1], "f3 wasn't modified on disk since last run."
            )
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 2 files left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In --check mode, a reformatted file alone makes the run fail.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), black.Changed.YES)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), black.Changed.NO)
            self.assertEqual(len(out_lines), 5)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 3 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 3 files would be left unchanged, "
                "2 files would fail to reformat.",
            )

    def test_is_python36(self) -> None:
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("function")
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("expression")
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))

    def test_debug_visitor(self) -> None:
        source, _ = read_data("debug_visitor.py")
        expected, _ = read_data("debug_visitor.out")
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            black.DebugVisitor.show(source)
        actual = "\n".join(out_lines) + "\n"
        log_name = ""
        if expected != actual:
            log_name = black.dump_to_file(*out_lines)
        self.assertEqual(
            expected,
            actual,
            f"AST print out is different. Actual version dumped to {log_name}",
        )

    def test_format_file_contents(self) -> None:
        empty = ""
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(empty, line_length=ll, fast=False)
        just_nl = "\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(just_nl, line_length=ll, fast=False)
        same = "l = [1, 2, 3]\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(same, line_length=ll, fast=False)
        different = "l = [1,2,3]"
        expected = same
        actual = black.format_file_contents(different, line_length=ll, fast=False)
        self.assertEqual(expected, actual)
        invalid = "return if you can"
        with self.assertRaises(ValueError) as e:
            black.format_file_contents(invalid, line_length=ll, fast=False)
        self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")

    def test_endmarker(self) -> None:
        n = black.lib2to3_parse("\n")
        self.assertEqual(n.type, black.syms.file_input)
        self.assertEqual(len(n.children), 1)
        self.assertEqual(n.children[0].type, black.token.ENDMARKER)

    @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
    def test_assertFormatEqual(self) -> None:
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            with self.assertRaises(AssertionError):
                self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")

        out_str = "".join(out_lines)
        self.assertIn("Expected tree:", out_str)
        self.assertIn("Actual tree:", out_str)
        self.assertEqual("".join(err_lines), "")

    def test_cache_broken_file(self) -> None:
        # A corrupt cache reads as empty, and the next run writes a valid one.
        with cache_dir() as workspace:
            with black.CACHE_FILE.open("w") as fobj:
                fobj.write("this is not a pickle")
            self.assertEqual(black.read_cache(), {})
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            cache = black.read_cache()
            self.assertIn(src, cache)

    def test_cache_single_file_already_cached(self) -> None:
        # A cached file is skipped, so its unformatted content survives the run.
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [src])
            result = CliRunner().invoke(black.main, [str(src)])
            self.assertEqual(result.exit_code, 0)
            with src.open("r") as fobj:
                self.assertEqual(fobj.read(), "print('hello')")

    @event_loop(close=False)
    def test_cache_multiple_files(self) -> None:
        with cache_dir() as workspace:
            one = (workspace / "one.py").resolve()
            with one.open("w") as fobj:
                fobj.write("print('hello')")
            two = (workspace / "two.py").resolve()
            with two.open("w") as fobj:
                fobj.write("print('hello')")
            black.write_cache({}, [one])
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 0)
            with one.open("r") as fobj:
                self.assertEqual(fobj.read(), "print('hello')")
            with two.open("r") as fobj:
                self.assertEqual(fobj.read(), 'print("hello")\n')
            cache = black.read_cache()
            self.assertIn(one, cache)
            self.assertIn(two, cache)

    def test_no_cache_when_writeback_diff(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            with src.open("w") as fobj:
                fobj.write("print('hello')")
            result = CliRunner().invoke(black.main, [str(src), "--diff"])
            self.assertEqual(result.exit_code, 0)
            self.assertFalse(black.CACHE_FILE.exists())

    def test_no_cache_when_stdin(self) -> None:
        with cache_dir():
            result = CliRunner().invoke(black.main, ["-"], input="print('hello')")
            self.assertEqual(result.exit_code, 0)
            self.assertFalse(black.CACHE_FILE.exists())

    def test_read_cache_no_cachefile(self) -> None:
        with cache_dir():
            self.assertEqual(black.read_cache(), {})

    def test_write_cache_read_cache(self) -> None:
        with cache_dir() as workspace:
            src = (workspace / "test.py").resolve()
            src.touch()
            black.write_cache({}, [src])
            cache = black.read_cache()
            self.assertIn(src, cache)
            self.assertEqual(cache[src], black.get_cache_info(src))

    def test_filter_cached(self) -> None:
        with TemporaryDirectory() as workspace:
            path = Path(workspace)
            uncached = (path / "uncached").resolve()
            cached = (path / "cached").resolve()
            cached_but_changed = (path / "changed").resolve()
            uncached.touch()
            cached.touch()
            cached_but_changed.touch()
            cache = {cached: black.get_cache_info(cached), cached_but_changed: (0.0, 0)}
            todo, done = black.filter_cached(
                cache, [uncached, cached, cached_but_changed]
            )
            self.assertEqual(todo, [uncached, cached_but_changed])
            self.assertEqual(done, [cached])

    def test_write_cache_creates_directory_if_needed(self) -> None:
        with cache_dir(exists=False) as workspace:
            self.assertFalse(workspace.exists())
            black.write_cache({}, [])
            self.assertTrue(workspace.exists())

    @event_loop(close=False)
    def test_failed_formatting_does_not_get_cached(self) -> None:
        with cache_dir() as workspace:
            failing = (workspace / "failing.py").resolve()
            with failing.open("w") as fobj:
                fobj.write("not actually python")
            clean = (workspace / "clean.py").resolve()
            with clean.open("w") as fobj:
                fobj.write('print("hello")\n')
            result = CliRunner().invoke(black.main, [str(workspace)])
            self.assertEqual(result.exit_code, 123)
            cache = black.read_cache()
            self.assertNotIn(failing, cache)
            self.assertIn(clean, cache)

    def test_write_cache_write_fail(self) -> None:
        # write_cache must not propagate an OSError raised while opening the file.
        with cache_dir(), patch.object(Path, "open") as mock:
            mock.side_effect = OSError
            black.write_cache({}, [])
597
598
if __name__ == "__main__":
    # Allow running the suite directly, e.g. ``python tests/test_black.py``.
    unittest.main(module="__main__")