git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Add a description for the pre-commit hook (#107)
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 from functools import partial
3 from io import StringIO
4 import os
5 from pathlib import Path
6 import sys
7 from typing import Any, List, Tuple
8 import unittest
9 from unittest.mock import patch
10
11 from click import unstyle
12
13 import black
14
# Line length passed to every formatting call in this module.
ll = 88
# ff formats a file in place (with fast=True); fs formats a source string.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
fs = partial(black.format_str, line_length=ll)
# This module doubles as a fixture: test_self formats it and expects no change.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() deletes from every fixture line. Presumably it is
# built from two literals so that this file (which test_self also loads through
# read_data()) never contains the full marker itself — keep it split.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
21
22
def dump_to_stderr(*output: str) -> str:
    """Return the *output* lines joined by newlines, framed by a newline each side.

    Used with ``patch`` as a stand-in for ``black.dump_to_file``. Despite the name,
    nothing is written anywhere; the text itself is returned. Presumably this is
    so the content shows up wherever black would have reported a dump file path.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
25
26
def read_data(name: str) -> Tuple[str, str]:
    """Return the ``(source, expected)`` pair stored in a test data file.

    A *name* not ending in ``.py``, ``.out`` or ``.diff`` gets ``.py`` appended,
    and the file is opened relative to ``THIS_DIR`` as UTF-8. ``EMPTY_LINE`` is
    removed from every line. Lines before a line reading ``# output`` (ignoring
    trailing whitespace) form the source; lines after it form the expected text.
    The marker line itself is dropped. If the source is non-empty but no expected
    lines were collected, the expected text is a copy of the source. Both strings
    are stripped and terminated with a single newline.
    """
    if not name.endswith((".py", ".out", ".diff")):
        name += ".py"
    with open(THIS_DIR / name, "r", encoding="utf8") as data_file:
        raw_lines = data_file.readlines()
    source_lines: List[str] = []
    output_lines: List[str] = []
    target = source_lines
    for raw_line in raw_lines:
        cleaned = raw_line.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            target = output_lines
        else:
            target.append(cleaned)
    if source_lines and not output_lines:
        # Without an expected section the fixture must already be formatted.
        output_lines = list(source_lines)
    return "".join(source_lines).strip() + "\n", "".join(output_lines).strip() + "\n"
47
48
class BlackTestCase(unittest.TestCase):
    """Tests for black's formatter, its command-line helpers and its reporting."""

    # Never truncate assertEqual diffs; formatted sources can be long.
    maxDiff = None

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that *actual* equals *expected*, printing both trees on mismatch.

        When the strings differ and the SKIP_AST_PRINT environment variable is
        not set, the trees produced by ``black.lib2to3_parse`` for both strings
        are printed with ``black.DebugVisitor`` before the assertion fails, to
        make formatting differences easier to diagnose.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            self._show_tree("Expected tree:", "green", expected)
            self._show_tree("Actual tree:", "red", actual)
        self.assertEqual(expected, actual)

    @staticmethod
    def _show_tree(title: str, color: str, src: str) -> None:
        """Print *title* in *color*, then the debug tree of *src*.

        Diagnostics are best-effort. Any exception raised while parsing or
        visiting *src* is printed via ``black.err`` rather than propagated, so
        that the caller's real assertion still runs.
        """
        black.out(title, fg=color)
        try:
            node = black.lib2to3_parse(src)
            bdv: black.DebugVisitor[Any] = black.DebugVisitor()
            list(bdv.visit(node))  # consume the iterator so every node is visited
        except Exception as ve:
            black.err(str(ve))

    def _assert_formatted(
        self, source: str, expected: str, actual: str, *, equivalent: bool = True
    ) -> None:
        """Check *actual*, the formatted form of *source*, against *expected*.

        Runs ``assertFormatEqual``. Then, on *source* and *actual*, it runs
        ``black.assert_equivalent`` (skipped when *equivalent* is false) and
        ``black.assert_stable`` at line length ``ll``.
        """
        self.assertFormatEqual(expected, actual)
        if equivalent:
            black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        source, expected = read_data("test_black")
        self._assert_formatted(source, expected, fs(source))
        # This test module must itself already be formatted.
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        source, expected = read_data("../black")
        self._assert_formatted(source, expected, fs(source))
        self.assertFalse(ff(THIS_DIR / ".." / "black.py"))

    def test_piping(self) -> None:
        source, expected = read_data("../black")
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=black.WriteBack.YES
            )
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        self._assert_formatted(source, expected, actual)

    def test_piping_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = "<stdin>"
            black.format_stdin_to_stdout(
                line_length=ll, fast=True, write_back=black.WriteBack.DIFF
            )
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        self.assertEqual(expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        source, expected = read_data("../setup")
        self._assert_formatted(source, expected, fs(source))
        self.assertFalse(ff(THIS_DIR / ".." / "setup.py"))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        source, expected = read_data("function")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        source, expected = read_data("expression")
        self._assert_formatted(source, expected, fs(source))

    def test_expression_ff(self) -> None:
        source, expected = read_data("expression")
        tmp_file = Path(black.dump_to_file(source))
        try:
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.YES))
            # Read back as UTF-8, matching how read_data() opens fixtures.
            with open(tmp_file, encoding="utf8") as f:
                actual = f.read()
        finally:
            os.unlink(tmp_file)
        with patch("black.dump_to_file", dump_to_stderr):
            self._assert_formatted(source, expected, actual)

    def test_expression_diff(self) -> None:
        source, _ = read_data("expression.py")
        expected, _ = read_data("expression.diff")
        tmp_file = Path(black.dump_to_file(source))
        hold_stdout = sys.stdout
        try:
            sys.stdout = StringIO()
            self.assertTrue(ff(tmp_file, write_back=black.WriteBack.DIFF))
            sys.stdout.seek(0)
            actual = sys.stdout.read()
            actual = actual.replace(tmp_file.name, "<stdin>")
        finally:
            sys.stdout = hold_stdout
            os.unlink(tmp_file)
        actual = actual.rstrip() + "\n"  # the diff output has a trailing space
        self.assertEqual(expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        source, expected = read_data("fstring")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_string_quotes(self) -> None:
        source, expected = read_data("string_quotes")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        source, expected = read_data("comments")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        source, expected = read_data("comments2")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments3(self) -> None:
        source, expected = read_data("comments3")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments4(self) -> None:
        source, expected = read_data("comments4")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        source, expected = read_data("cantfit")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        source, expected = read_data("import_spacing")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        source, expected = read_data("composition")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        source, expected = read_data("empty_lines")
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        source, expected = read_data("python2")
        # NOTE(review): equivalence is not checked for this fixture, presumably
        # because black.assert_equivalent cannot handle its Python 2-only syntax.
        # Confirm before enabling it.
        self._assert_formatted(source, expected, fs(source), equivalent=False)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fmtonoff(self) -> None:
        source, expected = read_data("fmtonoff")
        self._assert_formatted(source, expected, fs(source))

    def test_report(self) -> None:
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path("f1"), changed=False)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "f1 already well formatted, good job.")
            self.assertEqual(unstyle(str(report)), "1 file left unchanged.")
            self.assertEqual(report.return_code, 0)
            report.done(Path("f2"), changed=True)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], "reformatted f2")
            self.assertEqual(
                unstyle(str(report)), "1 file reformatted, 1 file left unchanged."
            )
            self.assertEqual(report.return_code, 0)
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path("e1"), "boom")
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], "error: cannot format e1: boom")
            self.assertEqual(
                unstyle(str(report)),
                "1 file reformatted, 1 file left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f3"), changed=True)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], "reformatted f3")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 1 file left unchanged, "
                "1 file failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path("e2"), "boom")
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], "error: cannot format e2: boom")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 1 file left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path("f4"), changed=False)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], "f4 already well formatted, good job.")
            self.assertEqual(
                unstyle(str(report)),
                "2 files reformatted, 2 files left unchanged, "
                "2 files failed to reformat.",
            )
            self.assertEqual(report.return_code, 123)
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                "2 files would be reformatted, 2 files would be left unchanged, "
                "2 files would fail to reformat.",
            )

    def test_is_python36(self) -> None:
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("function")
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data("expression")
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))

    def test_debug_visitor(self) -> None:
        source, _ = read_data("debug_visitor.py")
        expected, _ = read_data("debug_visitor.out")
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            black.DebugVisitor.show(source)
        actual = "\n".join(out_lines) + "\n"
        log_name = ""
        if expected != actual:
            log_name = black.dump_to_file(*out_lines)
        self.assertEqual(
            expected,
            actual,
            f"AST print out is different. Actual version dumped to {log_name}",
        )

    def test_format_file_contents(self) -> None:
        empty = ""
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(empty, line_length=ll, fast=False)
        just_nl = "\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(just_nl, line_length=ll, fast=False)
        same = "l = [1, 2, 3]\n"
        with self.assertRaises(black.NothingChanged):
            black.format_file_contents(same, line_length=ll, fast=False)
        different = "l = [1,2,3]"
        expected = same
        actual = black.format_file_contents(different, line_length=ll, fast=False)
        self.assertEqual(expected, actual)
        invalid = "return if you can"
        with self.assertRaises(ValueError) as e:
            black.format_file_contents(invalid, line_length=ll, fast=False)
        self.assertEqual(str(e.exception), "Cannot parse: 1:7: return if you can")

    def test_endmarker(self) -> None:
        n = black.lib2to3_parse("\n")
        self.assertEqual(n.type, black.syms.file_input)
        self.assertEqual(len(n.children), 1)
        self.assertEqual(n.children[0].type, black.token.ENDMARKER)

    @unittest.skipIf(os.environ.get("SKIP_AST_PRINT"), "user set SKIP_AST_PRINT")
    def test_assertFormatEqual(self) -> None:
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            with self.assertRaises(AssertionError):
                self.assertFormatEqual("l = [1, 2, 3]", "l = [1, 2, 3,]")

        out_str = "".join(out_lines)
        self.assertIn("Expected tree:", out_str)
        self.assertIn("Actual tree:", out_str)
        self.assertEqual("".join(err_lines), "")
437
438
if __name__ == "__main__":
    # Allow running the suite directly, e.g. ``python tests/test_black.py``.
    unittest.main()