git.madduck.net Git - etc/vim.git/blob - tests/test_black.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

7dba611141c099e9054421f55bd371ac242fced3
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 from functools import partial
3 from io import StringIO
4 from pathlib import Path
5 import sys
6 from typing import Any, List, Tuple
7 import unittest
8 from unittest.mock import patch
9
10 from click import unstyle
11
12 import black
13
# Line length passed to every formatting helper used by the tests in this file.
ll = 88
# Formats a file in place using the test line length; `fast=True` is forwarded to
# `black.format_file_in_place`.
ff = partial(black.format_file_in_place, line_length=ll, fast=True)
# Formats a source string using the test line length.
fs = partial(black.format_str, line_length=ll)
# Path of this test module and of its directory; `read_data` resolves data file
# names relative to THIS_DIR.
THIS_FILE = Path(__file__)
THIS_DIR = THIS_FILE.parent
19
20
def dump_to_stderr(*output: str) -> str:
    """Join `output` with newlines and wrap the result in a newline on each side.

    The tests in this file patch this function in for `black.dump_to_file`.
    """
    joined = '\n'.join(output)
    return '\n{}\n'.format(joined)
23
24
def read_data(name: str) -> Tuple[str, str]:
    """Load a test data file and split it into (source, expected output).

    `name` is resolved relative to THIS_DIR; '.py' is appended when it is missing.
    Lines before a line that reads '# output' (trailing whitespace ignored) form
    the source, lines after it form the expected output, and the marker line
    itself is dropped.  When nothing follows the marker, or there is no marker at
    all, the source doubles as the expected output.  Both returned strings are
    stripped of surrounding whitespace and terminated with a single newline.
    """
    filename = name if name.endswith('.py') else name + '.py'
    source_lines: List[str] = []
    expected_lines: List[str] = []
    target = source_lines
    with open(THIS_DIR / filename, 'r', encoding='utf8') as data_file:
        for text_line in data_file:
            if text_line.rstrip() == '# output':
                # Everything from here on belongs to the expected output.
                target = expected_lines
            else:
                target.append(text_line)
    # Without explicit expected output, the file is treated as already formatted.
    source = ''.join(source_lines).strip() + '\n'
    expected = ''.join(expected_lines or source_lines).strip() + '\n'
    return source, expected
44
45
class BlackTestCase(unittest.TestCase):
    """Tests for the `black` formatter: data-driven formatting checks and reporting.

    Most tests read a data file with `read_data`, format its source with `fs` and
    verify the result with `black.assert_equivalent` and `black.assert_stable`.
    `black.dump_to_file` is patched with `dump_to_stderr` for those tests.
    """

    # None removes the length limit on diffs that unittest prints on failure.
    maxDiff = None

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that `actual` equals `expected`.

        On a mismatch, the syntax trees of both sources are run through
        `black.DebugVisitor` before the assertion fails.
        """
        if actual != expected:
            self._dump_tree('Expected tree:', expected, fg='green')
            self._dump_tree('Actual tree:', actual, fg='red')
        self.assertEqual(expected, actual)

    @staticmethod
    def _dump_tree(header: str, source: str, fg: str) -> None:
        """Emit `header` via `black.out`, then walk the parsed tree of `source`.

        Dumping is best effort: any exception raised while parsing or visiting is
        passed to `black.err` instead of propagating, so that the assertion in
        `assertFormatEqual` still runs.
        """
        black.out(header, fg=fg)
        try:
            tree = black.lib2to3_parse(source)
            visitor: black.DebugVisitor[Any] = black.DebugVisitor()
            list(visitor.visit(tree))
        except Exception as exc:
            black.err(str(exc))

    def _check_result(self, source: str, expected: str, actual: str) -> None:
        """Verify `actual` as the formatted form of `source`.

        Compares it with `expected`, then runs `black.assert_equivalent` and
        `black.assert_stable` on the pair.
        """
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def _check_data(self, name: str) -> None:
        """Format the source read from data file `name` and verify the result."""
        source, expected = read_data(name)
        self._check_result(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        self._check_data('test_black')
        # Formatting this file in place must report no change.
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        self._check_data('../black')
        self.assertFalse(ff(THIS_DIR / '..' / 'black.py'))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_piping(self) -> None:
        source, expected = read_data('../black')
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            # Swap the standard streams for in-memory buffers for the duration of
            # the call; the originals are restored even if formatting raises.
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = '<stdin>'
            black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True)
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        self._check_result(source, expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        self._check_data('../setup')
        self.assertFalse(ff(THIS_DIR / '..' / 'setup.py'))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        self._check_data('function')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        self._check_data('expression')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        self._check_data('fstring')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        self._check_data('comments')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        self._check_data('comments2')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        self._check_data('cantfit')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        self._check_data('import_spacing')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        self._check_data('composition')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        self._check_data('empty_lines')

    def test_report(self) -> None:
        report = black.Report()
        # Messages captured from the patched `black.out` and `black.err`.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            # One unchanged file.
            report.done(Path('f1'), changed=False)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], 'f1 already well formatted, good job.')
            self.assertEqual(unstyle(str(report)), '1 file left unchanged.')
            self.assertEqual(report.return_code, 0)
            # First reformatted file: the return code becomes 1.
            report.done(Path('f2'), changed=True)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], 'reformatted f2')
            self.assertEqual(
                unstyle(str(report)), '1 file reformatted, 1 file left unchanged.'
            )
            self.assertEqual(report.return_code, 1)
            # First failure: the return code becomes 123 and stays there.
            report.failed(Path('e1'), 'boom')
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], 'error: cannot format e1: boom')
            self.assertEqual(
                unstyle(str(report)),
                '1 file reformatted, 1 file left unchanged, '
                '1 file failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path('f3'), changed=True)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], 'reformatted f3')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 1 file left unchanged, '
                '1 file failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path('e2'), 'boom')
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], 'error: cannot format e2: boom')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 1 file left unchanged, '
                '2 files failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path('f4'), changed=False)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], 'f4 already well formatted, good job.')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 2 files left unchanged, '
                '2 files failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)

    def test_is_python36(self) -> None:
        # A bare keyword-only marker alone is not treated as Python 3.6 syntax.
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        # A trailing comma after the keyword-only argument is.
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        # So is an f-string.
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        # Detection must agree between the source and expected halves of a data file.
        source, expected = read_data('function')
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data('expression')
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))
267
# Run the tests through unittest when this module is executed as a script.
if __name__ == '__main__':
    unittest.main()