git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Parse complex expressions in parameters after * and **
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 from functools import partial
3 from io import StringIO
4 import os
5 from pathlib import Path
6 import sys
7 from typing import Any, List, Tuple
8 import unittest
9 from unittest.mock import patch
10
11 from click import unstyle
12
13 import black
14
# Line length passed to every formatting call made in this module.
ll = 88
# Shorthands: format a file in place (with fast=True) or format a string.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
fs = partial(black.format_str, line_length=ll)
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# read_data() deletes this marker text from every line it reads, which lets data
# files express whitespace-only lines. NOTE(review): the literal is presumably
# split in two so that reading this very file (test_self) leaves the definition
# itself intact -- confirm before merging the two halves.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
21
22
def dump_to_stderr(*output: str) -> str:
    """Return ``output`` joined by newlines, with one newline before and after.

    The tests patch this in as ``black.dump_to_file``. Callers then receive the
    dumped text itself rather than the name of a file that was written.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
25
26
def read_data(name: str) -> Tuple[str, str]:
    """Load a test data file and return its ``(input, expected_output)`` pair.

    ``.py`` is appended to ``name`` unless it already ends in ``.py``, ``.out``
    or ``.diff``. The path is resolved relative to THIS_DIR. Every line has the
    EMPTY_LINE marker removed. Lines after a line reading ``# output`` (trailing
    whitespace ignored) are the expected output; the marker lines themselves
    are dropped. If the input is non-empty but no expected output was given,
    the input doubles as the expected output. Both results are stripped and
    end with exactly one newline.
    """
    if not name.endswith((".py", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as fh:
        raw_lines = fh.readlines()
    before: List[str] = []
    after: List[str] = []
    seen_marker = False
    for raw in raw_lines:
        line = raw.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            seen_marker = True
        elif seen_marker:
            after.append(line)
        else:
            before.append(line)
    if before and not after:
        # No expected section: the input is treated as already formatted.
        after = list(before)
    return "".join(before).strip() + "\n", "".join(after).strip() + "\n"
47
48
class BlackTestCase(unittest.TestCase):
    """Tests for black's formatting functions, I/O entry points and reporting."""

    # Show the complete diff when an equality assertion on long strings fails.
    maxDiff = None

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that ``actual`` equals ``expected``.

        On a mismatch, and unless the SKIP_AST_PRINT environment variable is
        set, the lib2to3 trees of both sources are printed before the
        assertion fails, to make the difference easier to diagnose.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            self._show_tree("Expected tree:", "green", expected)
            self._show_tree("Actual tree:", "red", actual)
        self.assertEqual(expected, actual)

    def _show_tree(self, title: str, color: str, source: str) -> None:
        """Print ``title``, then run a DebugVisitor over the tree of ``source``.

        Parse errors are reported through ``black.err`` instead of being raised,
        so the caller's assertEqual still runs and shows the textual diff.
        """
        bdv: black.DebugVisitor[Any]
        black.out(title, fg=color)
        try:
            node = black.lib2to3_parse(source)
            bdv = black.DebugVisitor()
            # visit() returns an iterable; exhaust it so every node is visited.
            list(bdv.visit(node))
        except Exception as ve:
            black.err(str(ve))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        # This test module must itself already be formatted by black.
        source, expected = read_data("test_black")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        source, expected = read_data("../black")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)
        self.assertFalse(ff(THIS_DIR / ".." / "black.py"))

    def test_piping(self) -> None:
        source, expected = read_data("../black")
        # format_stdin_to_stdout works on sys.stdin/sys.stdout; swap in memory
        # buffers and always restore the real streams afterwards.
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=black.WriteBack.YES
            )
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_piping_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=black.WriteBack.DIFF
            )
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        self.assertEqual(expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        source, expected = read_data("../setup")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)
        self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        source, expected = read_data("function")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        source, expected = read_data("expression")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_expression_ff(self) -> None:
        source, expected = read_data("expression")
        # Not patched: dump_to_file really writes a temporary file here.
        tmp_file = Path(black.dump_to_file(source))
        try:
            # A truthy result means the file had to be rewritten.
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
            with open(tmp_file, encoding="utf8") as f:
                actual = f.read()
        finally:
            os.unlink(tmp_file)
        self.assertFormatEqual(expected, actual)
        with patch("black.dump_to_file", dump_to_stderr):
            black.assert_equivalent(source, actual)
            black.assert_stable(source, actual, line_length=ll)

    def test_expression_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        tmp_file = Path(black.dump_to_file(source))
        hold_stdout = sys.stdout
        try:
            sys.stdout = StringIO()
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
            sys.stdout.seek(0)
            actual = sys.stdout.read()
            # The diff header names the temporary file; normalize it so the
            # output can be compared with the stored expectation.
            actual = actual.replace(tmp_file.name, "<stdin>")
        finally:
            sys.stdout = hold_stdout
            os.unlink(tmp_file)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        if expected != actual:
            dump = black.dump_to_file(actual)
            msg = (
                "Expected diff isn't equal to the actual. If you made changes "
                "to expression.py and this is an anticipated difference, "
                f"overwrite tests/expression.diff with {dump}"
            )
            self.assertEqual(expected, actual, msg)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        source, expected = read_data("fstring")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_string_quotes(self) -> None:
        source, expected = read_data("string_quotes")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        source, expected = read_data("comments")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        source, expected = read_data("comments2")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments3(self) -> None:
        source, expected = read_data("comments3")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments4(self) -> None:
        source, expected = read_data("comments4")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        source, expected = read_data("cantfit")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        source, expected = read_data("import_spacing")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        source, expected = read_data("composition")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        source, expected = read_data("empty_lines")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        source, expected = read_data("python2")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        # NOTE(review): black.assert_equivalent is deliberately skipped here,
        # presumably because it parses code with the running Python 3
        # interpreter, which rejects Python 2 syntax -- TODO confirm.
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fmtonoff(self) -> None:
        source, expected = read_data("fmtonoff")
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def test_report(self) -> None:
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), changed=False)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), changed=True)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            # In check mode, a reformatted file alone makes the run fail.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 1 file left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), changed=True)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 1 file left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 1 file left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), changed=False)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 2 files would be left unchanged, "
                "2 files would fail to reformat.",
            )

    def test_is_python36(self) -> None:
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("function")
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("expression")
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))

    def test_debug_visitor(self) -> None:
        source, _ = read_data("debug_visitor.py")
        expected, _ = read_data("debug_visitor.out")
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            black.DebugVisitor.show(source)
        actual = "\n".join(out_lines) + "\n"
        log_name = ""
        if expected != actual:
            log_name = black.dump_to_file(*out_lines)
        self.assertEqual(
            expected,
            actual,
            f"AST print out is different. Actual version dumped to {log_name}",
        )

    def test_format_file_contents(self) -> None:
        empty = ""
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(empty, line_length=ll, fast=False)
        just_nl = "\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(just_nl, line_length=ll, fast=False)
        same = "l = [1, 2, 3]\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(same, line_length=ll, fast=False)
        different = "l = [1,2,3]"
        expected = same
        actual = black.format_file_contents(different, line_length=ll, fast=False)
        self.assertEqual(expected, actual)
        invalid = "return if you can"
        with self.assertRaises(ValueError) as e:
            black.format_file_contents(invalid, line_length=ll, fast=False)
        self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")

    def test_endmarker(self) -> None:
        n = black.lib2to3_parse("\n")
        self.assertEqual(n.type, black.syms.file_input)
        self.assertEqual(len(n.children), 1)
        self.assertEqual(n.children[0].type, black.token.ENDMARKER)

    @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
    def test_assertFormatEqual(self) -> None:
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            with self.assertRaises(AssertionError):
                self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")

        out_str = "".join(out_lines)
        self.assertIn("Expected tree:", out_str)
        self.assertIn("Actual tree:", out_str)
        self.assertEqual("".join(err_lines), "")
444
445
if __name__ == "__main__":
    # Run the test suite when this file is executed directly as a script.
    unittest.main()