git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Don't crash and burn on empty lines with trailing whitespace
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 from functools import partial
3 from io import StringIO
4 import os
5 from pathlib import Path
6 import sys
7 from typing import Any, List, Tuple
8 import unittest
9 from unittest.mock import patch
10
11 from click import unstyle
12
13 import black
14
# Line length handed to every black call made from this module.
ll = 88
# Format a file in place using the module-wide line length, in `fast` mode.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
# Format a source string using the module-wide line length.
fs = partial(black.format_str, line_length=ll)
# Location of this test module; read_data() opens data files relative to THIS_DIR.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
# Marker text that read_data() removes from every line it reads. The literal is
# split in two, presumably so that this line itself never contains the marker
# when test_self feeds this file through read_data() -- TODO confirm.
EMPTY_LINE = '# EMPTY LINE WITH WHITESPACE' + ' (this comment will be removed)'
21
22
def dump_to_stderr(*output: str) -> str:
    """Return `output` joined by newlines, with a newline before and after.

    Despite the name nothing is written anywhere; the text is only returned.
    """
    body = '\n'.join(output)
    return f'\n{body}\n'
25
26
def read_data(name: str) -> Tuple[str, str]:
    """Load a data file from THIS_DIR and return its (input, output) halves.

    `name` gets a '.py' suffix unless it already ends in '.py' or '.out'. Lines
    before a bare '# output' marker line form the input, lines after it form the
    output; when nothing follows a marker the input doubles as the output. The
    EMPTY_LINE text is removed from every line, and both halves are stripped and
    terminated with a single newline.
    """
    filename = name if name.endswith(('.py', '.out')) else name + '.py'
    before_marker: List[str] = []
    after_marker: List[str] = []
    target = before_marker
    with open(THIS_DIR / filename, 'r', encoding='utf8') as test:
        for raw in test.readlines():
            cleaned = raw.replace(EMPTY_LINE, '')
            if cleaned.rstrip() != '# output':
                target.append(cleaned)
            else:
                # Everything after the marker belongs to the expected output.
                target = after_marker
    if before_marker and not after_marker:
        # Nothing collected as output: the file is its own expectation.
        after_marker = before_marker[:]
    source = ''.join(before_marker).strip() + '\n'
    expected = ''.join(after_marker).strip() + '\n'
    return source, expected
47
48
class BlackTestCase(unittest.TestCase):
    """Tests for `black`, mostly driven by data files read via `read_data`."""

    # Do not truncate the diffs unittest prints for failed comparisons.
    maxDiff = None

    @staticmethod
    def _dump_tree(title: str, color: str, source: str) -> None:
        """Print `title` in `color`, then run DebugVisitor over parsed `source`.

        Any exception raised while parsing or visiting is reported through
        `black.err` instead of propagating, so the caller can still go on to
        its real assertion.
        """
        black.out(title, fg=color)
        try:
            node = black.lib2to3_parse(source)
            bdv: black.DebugVisitor[Any] = black.DebugVisitor()
            # visit() is consumed via list() purely to drive the visitor.
            list(bdv.visit(node))
        except Exception as ve:
            black.err(str(ve))

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert `actual` equals `expected`, dumping both trees on a mismatch.

        The tree dumps are skipped when the SKIP_AST_PRINT environment variable
        is set to a non-empty value.
        """
        if actual != expected and not os.environ.get('SKIP_AST_PRINT'):
            self._dump_tree('Expected tree:', 'green', expected)
            self._dump_tree('Actual tree:', 'red', actual)
        self.assertEqual(expected, actual)

    def _check_format(self, name: str, *, check_equivalent: bool = True) -> None:
        """Format the data file `name` and verify the result.

        Reads the (source, expected) pair with `read_data`, formats the source
        with `fs`, compares it to the expectation and then runs black's own
        stability check. The equivalence check runs unless `check_equivalent`
        is False.
        """
        source, expected = read_data(name)
        actual = fs(source)
        self.assertFormatEqual(expected, actual)
        if check_equivalent:
            black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        """This test module is already formatted and `ff` reports no change."""
        self._check_format('test_black')
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        """black.py is already formatted and `ff` reports no change."""
        self._check_format('../black')
        self.assertFalse(ff(THIS_DIR / '..' / 'black.py'))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_piping(self) -> None:
        """Formatting from stdin to stdout yields the expected text."""
        source, expected = read_data('../black')
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = '<stdin>'
            black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True)
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            # Always restore the real streams, even if formatting raised.
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        """setup.py is already formatted and `ff` reports no change."""
        self._check_format('../setup')
        self.assertFalse(ff(THIS_DIR / '..' / 'setup.py'))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        """Check the 'function' data file."""
        self._check_format('function')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        """Check the 'expression' data file."""
        self._check_format('expression')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        """Check the 'fstring' data file."""
        self._check_format('fstring')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        """Check the 'comments' data file."""
        self._check_format('comments')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        """Check the 'comments2' data file."""
        self._check_format('comments2')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        """Check the 'cantfit' data file."""
        self._check_format('cantfit')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        """Check the 'import_spacing' data file."""
        self._check_format('import_spacing')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        """Check the 'composition' data file."""
        self._check_format('composition')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        """Check the 'empty_lines' data file."""
        self._check_format('empty_lines')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        """Check the 'python2' data file, without the equivalence check."""
        # NOTE(review): the equivalence check was disabled here from the start;
        # presumably the Python 2 sample cannot be handled by it -- confirm.
        self._check_format('python2', check_equivalent=False)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fmtonoff(self) -> None:
        """Check the 'fmtonoff' data file."""
        self._check_format('fmtonoff')

    def test_report(self) -> None:
        """Check Report messages, summary text and return code as results add up."""
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Stand-ins for black.out / black.err that record the message only;
        # styling keyword arguments are accepted and ignored.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path('f1'), changed=False)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], 'f1 already well formatted, good job.')
            self.assertEqual(unstyle(str(report)), '1 file left unchanged.')
            self.assertEqual(report.return_code, 0)
            report.done(Path('f2'), changed=True)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], 'reformatted f2')
            self.assertEqual(
                unstyle(str(report)), '1 file reformatted, 1 file left unchanged.'
            )
            self.assertEqual(report.return_code, 0)
            # In check mode a changed file alone makes the return code non-zero.
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path('e1'), 'boom')
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], 'error: cannot format e1: boom')
            self.assertEqual(
                unstyle(str(report)),
                '1 file reformatted, 1 file left unchanged, '
                '1 file failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path('f3'), changed=True)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], 'reformatted f3')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 1 file left unchanged, '
                '1 file failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path('e2'), 'boom')
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], 'error: cannot format e2: boom')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 1 file left unchanged, '
                '2 files failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path('f4'), changed=False)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], 'f4 already well formatted, good job.')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 2 files left unchanged, '
                '2 files failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            # Check mode switches the summary to the conditional wording.
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                '2 files would be reformatted, 2 files would be left unchanged, '
                '2 files would fail to reformat.',
            )

    def test_is_python36(self) -> None:
        """Check is_python36 on small snippets and on two data files."""
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        source, expected = read_data('function')
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data('expression')
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))

    def test_debug_visitor(self) -> None:
        """Compare DebugVisitor.show output for debug_visitor.py to the .out file."""
        source, _ = read_data('debug_visitor.py')
        expected, _ = read_data('debug_visitor.out')
        out_lines: List[str] = []
        err_lines: List[str] = []

        # Stand-ins for black.out / black.err that record the message only.
        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            black.DebugVisitor.show(source)
        actual = '\n'.join(out_lines) + '\n'
        log_name = ''
        if expected != actual:
            # Keep the actual output around so the failure message can point at it.
            log_name = black.dump_to_file(*out_lines)
        self.assertEqual(
            expected,
            actual,
            f"AST print out is different. Actual version dumped to {log_name}",
        )
319
# Run the test suite when this module is executed directly as a script.
if __name__ == '__main__':
    unittest.main()