git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you would read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

The site is cleaner without the 'Related' cruft.
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 from functools import partial
3 from io import StringIO
4 import os
5 from pathlib import Path
6 import sys
7 from typing import Any, List, Tuple
8 import unittest
9 from unittest.mock import patch
10
11 from click import unstyle
12
13 import black
14
# Location of this test module; fixture files are resolved relative to its directory.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent

# Line length shared by every formatting call in this module.
ll = 88
# Shorthands pinned to `ll`: `fs` wraps black.format_str and `ff` wraps
# black.format_file_in_place (with fast=True).
fs = partial(black.format_str, line_length=ll)
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
20
21
def dump_to_stderr(*output: str) -> str:
    """Replacement for ``black.dump_to_file`` installed via ``patch`` in the tests.

    Nothing is written anywhere: the given chunks are joined with newlines
    and returned with one leading and one trailing newline added.
    """
    body = '\n'.join(output)
    return f'\n{body}\n'
24
25
def read_data(name: str) -> Tuple[str, str]:
    """Return the ``(input, output)`` pair stored in the fixture file ``name``.

    A bare name gets a ``.py`` suffix; names already ending in ``.py`` or
    ``.out`` are used as given. The file is resolved relative to THIS_DIR
    and read as UTF-8. Lines before a ``# output`` marker line form the
    input and lines after it form the expected output; marker lines are
    dropped. If the input is non-empty but nothing follows a marker, the
    expected output is a copy of the input (the fixture is treated as
    already formatted). Both halves are stripped and end in one newline.
    """
    if not name.endswith(('.py', '.out')):
        name += '.py'
    with open(THIS_DIR / name, 'r', encoding='utf8') as fixture:
        all_lines = fixture.readlines()

    # Slot 0 collects the input, slot 1 the expected output.
    sections: List[List[str]] = [[], []]
    current = 0
    for text in all_lines:
        if text.rstrip() == '# output':
            current = 1
            continue
        sections[current].append(text)

    source_lines, expected_lines = sections
    if source_lines and not expected_lines:
        # No output section: the whole file is expected to be pre-formatted.
        expected_lines = list(source_lines)

    def _joined(chunk: List[str]) -> str:
        return ''.join(chunk).strip() + '\n'

    return _joined(source_lines), _joined(expected_lines)
45
46
class BlackTestCase(unittest.TestCase):
    """Tests for black's formatting, stdin/stdout plumbing, reporting and helpers."""

    # Always show the complete diff when a string comparison fails.
    maxDiff = None

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that ``actual`` equals ``expected``.

        On a mismatch, unless the SKIP_AST_PRINT environment variable is set,
        the debug trees of both strings are printed first (expected, then
        actual) to make the difference easier to diagnose.
        """
        if actual != expected and not os.environ.get('SKIP_AST_PRINT'):
            self._show_tree('Expected tree:', 'green', expected)
            self._show_tree('Actual tree:', 'red', actual)
        self.assertEqual(expected, actual)

    def _show_tree(self, title: str, color: str, code: str) -> None:
        """Print ``title`` followed by black's debug dump of ``code``'s tree."""
        black.out(title, fg=color)
        try:
            node = black.lib2to3_parse(code)
            bdv: black.DebugVisitor[Any] = black.DebugVisitor()
            # Drain the iterator returned by visit() so the whole tree is walked.
            list(bdv.visit(node))
        except Exception as ve:
            # Best-effort diagnostics: an unparsable side is reported, not
            # raised, so the real assertion in assertFormatEqual still runs.
            black.err(str(ve))

    def _check_output(
        self, source: str, expected: str, actual: str, *, equivalent: bool = True
    ) -> None:
        """Compare ``actual`` with ``expected`` and run black's safety checks.

        ``equivalent=False`` skips ``black.assert_equivalent``; the stability
        check at the module-wide line length always runs.
        """
        self.assertFormatEqual(expected, actual)
        if equivalent:
            black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def _check_fixture(self, name: str, *, equivalent: bool = True) -> None:
        """Format fixture ``name`` with ``fs`` and verify it via ``_check_output``."""
        source, expected = read_data(name)
        actual = fs(source)
        self._check_output(source, expected, actual, equivalent=equivalent)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        self._check_fixture('test_black')
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        self._check_fixture('../black')
        self.assertFalse(ff(THIS_DIR / '..' / 'black.py'))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_piping(self) -> None:
        source, expected = read_data('../black')
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = '<stdin>'
            black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True)
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            # Restore the real streams even if formatting raised.
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        self._check_output(source, expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        self._check_fixture('../setup')
        self.assertFalse(ff(THIS_DIR / '..' / 'setup.py'))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        self._check_fixture('function')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        self._check_fixture('expression')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        self._check_fixture('fstring')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        self._check_fixture('comments')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        self._check_fixture('comments2')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        self._check_fixture('cantfit')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        self._check_fixture('import_spacing')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        self._check_fixture('composition')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        self._check_fixture('empty_lines')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_python2(self) -> None:
        # AST equivalence is intentionally not checked here (the original test
        # had that call commented out) -- presumably because Python 3's `ast`
        # cannot parse Python 2 syntax; confirm before re-enabling.
        self._check_fixture('python2', equivalent=False)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fmtonoff(self) -> None:
        self._check_fixture('fmtonoff')

    def test_report(self) -> None:
        report = black.Report()
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path('f1'), changed=False)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], 'f1 already well formatted, good job.')
            self.assertEqual(unstyle(str(report)), '1 file left unchanged.')
            self.assertEqual(report.return_code, 0)
            report.done(Path('f2'), changed=True)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], 'reformatted f2')
            self.assertEqual(
                unstyle(str(report)), '1 file reformatted, 1 file left unchanged.'
            )
            self.assertEqual(report.return_code, 0)
            report.check = True
            self.assertEqual(report.return_code, 1)
            report.check = False
            report.failed(Path('e1'), 'boom')
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], 'error: cannot format e1: boom')
            self.assertEqual(
                unstyle(str(report)),
                '1 file reformatted, 1 file left unchanged, '
                '1 file failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path('f3'), changed=True)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], 'reformatted f3')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 1 file left unchanged, '
                '1 file failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path('e2'), 'boom')
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], 'error: cannot format e2: boom')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 1 file left unchanged, '
                '2 files failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path('f4'), changed=False)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], 'f4 already well formatted, good job.')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 2 files left unchanged, '
                '2 files failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.check = True
            self.assertEqual(
                unstyle(str(report)),
                '2 files would be reformatted, 2 files would be left unchanged, '
                '2 files would fail to reformat.',
            )

    def test_is_python36(self) -> None:
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        source, expected = read_data('function')
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data('expression')
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))

    def test_debug_visitor(self) -> None:
        source, _ = read_data('debug_visitor.py')
        expected, _ = read_data('debug_visitor.out')
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            black.DebugVisitor.show(source)
        actual = '\n'.join(out_lines) + '\n'
        log_name = ''
        if expected != actual:
            # Unpatched here on purpose: the real dump_to_file keeps the actual
            # output around and its return value is named in the failure message.
            log_name = black.dump_to_file(*out_lines)
        self.assertEqual(
            expected,
            actual,
            f"AST print out is different. Actual version dumped to {log_name}",
        )
316
317
# Allow running this test module directly as a script (e.g. `python tests/test_black.py`).
if __name__ == '__main__':
    unittest.main()