madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Mention joslarson.black-vscode
[etc/vim.git] / tests / test_black.py
1 #!/usr/bin/env python3
2 from functools import partial
3 from io import StringIO
4 import os
5 from pathlib import Path
6 import sys
7 from typing import Any, List, Tuple
8 import unittest
9 from unittest.mock import patch
10
11 from click import unstyle
12
13 import black
14
# Line length used by every formatting shortcut and stability check below.
ll = 88
# ``fs`` formats a source string; ``ff`` formats a file in place (fast=True).
fs = partial(black.format_str, line_length=ll)
ff = partial(black.format_file_in_place, fast=True, line_length=ll)
# This test module and the directory holding its data files.
THIS_FILE = Path(__file__)
THIS_DIR = Path(__file__).parent
20
21
def dump_to_stderr(*output: str) -> str:
    """Stand-in for ``black.dump_to_file`` that returns text instead of a path.

    The tests install it with ``patch``.  The given pieces are joined with
    newlines and wrapped in one leading and one trailing newline.  Presumably
    this makes the content itself, rather than a temporary file name, show up
    in black's failure messages.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
24
25
def read_data(name: str) -> Tuple[str, str]:
    """Load test data file *name* from THIS_DIR and split it into (input, output).

    A ``.py`` suffix is appended when *name* lacks one.  Every line after the
    first line reading ``# output`` (trailing whitespace ignored) belongs to the
    expected output; further marker lines are dropped.  If there is input but
    no output section, the file counts as already formatted and the input is
    reused as the output.  Both halves are stripped and end in one newline.
    """
    filename = name if name.endswith('.py') else name + '.py'
    with open(THIS_DIR / filename, 'r', encoding='utf8') as handle:
        all_lines = handle.readlines()

    source_part: List[str] = []
    expected_part: List[str] = []
    seen_marker = False
    for text in all_lines:
        if text.rstrip() == '# output':
            seen_marker = True
        elif seen_marker:
            expected_part.append(text)
        else:
            source_part.append(text)

    if source_part and not expected_part:
        # No "# output" section: treat the whole file as pre-formatted.
        expected_part = list(source_part)
    return (
        ''.join(source_part).strip() + '\n',
        ''.join(expected_part).strip() + '\n',
    )
45
46
class BlackTestCase(unittest.TestCase):
    """End-to-end tests for black: formatting fixtures, piping and reporting."""

    maxDiff = None

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert ``expected == actual``, printing both syntax trees on mismatch.

        The tree dump is skipped when the ``SKIP_AST_PRINT`` environment
        variable is set to a non-empty value.
        """
        if actual != expected and not os.environ.get('SKIP_AST_PRINT'):
            black.out('Expected tree:', fg='green')
            self._print_tree(expected)
            black.out('Actual tree:', fg='red')
            self._print_tree(actual)
        self.assertEqual(expected, actual)

    # Helper names must not start with "test", or unittest would collect them.

    @staticmethod
    def _print_tree(src: str) -> None:
        """Walk the lib2to3 tree of *src* with ``black.DebugVisitor``.

        Parse or visit errors are reported through ``black.err`` instead of
        being raised. This way the caller's ``assertEqual`` still reports the
        real mismatch.
        """
        try:
            node = black.lib2to3_parse(src)
            bdv: black.DebugVisitor[Any] = black.DebugVisitor()
            # The collected items are discarded; iterating drives the visitor.
            list(bdv.visit(node))
        except Exception as ve:
            black.err(str(ve))

    def _assert_formatted(self, source: str, expected: str, actual: str) -> None:
        """Verify *actual*, the formatted version of *source*.

        It must equal *expected*, be equivalent to *source* according to
        ``black.assert_equivalent``, and pass ``black.assert_stable`` at line
        length ``ll``.
        """
        self.assertFormatEqual(expected, actual)
        black.assert_equivalent(source, actual)
        black.assert_stable(source, actual, line_length=ll)

    def _check_data(self, name: str) -> None:
        """Format the input half of test data file *name* with ``fs`` and verify it."""
        source, expected = read_data(name)
        self._assert_formatted(source, expected, fs(source))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_self(self) -> None:
        self._check_data('test_black')
        # ff() on this very file must return a falsy result (nothing to reformat).
        self.assertFalse(ff(THIS_FILE))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_black(self) -> None:
        self._check_data('../black')
        self.assertFalse(ff(THIS_DIR / '..' / 'black.py'))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_piping(self) -> None:
        source, expected = read_data('../black')
        hold_stdin, hold_stdout = sys.stdin, sys.stdout
        try:
            sys.stdin, sys.stdout = StringIO(source), StringIO()
            sys.stdin.name = '<stdin>'
            black.format_stdin_to_stdout(line_length=ll, fast=True, write_back=True)
            sys.stdout.seek(0)
            actual = sys.stdout.read()
        finally:
            # Restore the real streams before asserting so failures are visible.
            sys.stdin, sys.stdout = hold_stdin, hold_stdout
        self._assert_formatted(source, expected, actual)

    @patch("black.dump_to_file", dump_to_stderr)
    def test_setup(self) -> None:
        self._check_data('../setup')
        self.assertFalse(ff(THIS_DIR / '..' / 'setup.py'))

    @patch("black.dump_to_file", dump_to_stderr)
    def test_function(self) -> None:
        self._check_data('function')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_expression(self) -> None:
        self._check_data('expression')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_fstring(self) -> None:
        self._check_data('fstring')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments(self) -> None:
        self._check_data('comments')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_comments2(self) -> None:
        self._check_data('comments2')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_cantfit(self) -> None:
        self._check_data('cantfit')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_import_spacing(self) -> None:
        self._check_data('import_spacing')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_composition(self) -> None:
        self._check_data('composition')

    @patch("black.dump_to_file", dump_to_stderr)
    def test_empty_lines(self) -> None:
        self._check_data('empty_lines')

    def test_report(self) -> None:
        report = black.Report()
        # Capture what Report sends to black.out / black.err.
        out_lines: List[str] = []
        err_lines: List[str] = []

        def out(msg: str, **kwargs: Any) -> None:
            out_lines.append(msg)

        def err(msg: str, **kwargs: Any) -> None:
            err_lines.append(msg)

        with patch("black.out", out), patch("black.err", err):
            report.done(Path('f1'), changed=False)
            self.assertEqual(len(out_lines), 1)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], 'f1 already well formatted, good job.')
            self.assertEqual(unstyle(str(report)), '1 file left unchanged.')
            self.assertEqual(report.return_code, 0)
            report.done(Path('f2'), changed=True)
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 0)
            self.assertEqual(out_lines[-1], 'reformatted f2')
            self.assertEqual(
                unstyle(str(report)), '1 file reformatted, 1 file left unchanged.'
            )
            self.assertEqual(report.return_code, 1)
            report.failed(Path('e1'), 'boom')
            self.assertEqual(len(out_lines), 2)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(err_lines[-1], 'error: cannot format e1: boom')
            self.assertEqual(
                unstyle(str(report)),
                '1 file reformatted, 1 file left unchanged, '
                '1 file failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path('f3'), changed=True)
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 1)
            self.assertEqual(out_lines[-1], 'reformatted f3')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 1 file left unchanged, '
                '1 file failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.failed(Path('e2'), 'boom')
            self.assertEqual(len(out_lines), 3)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(err_lines[-1], 'error: cannot format e2: boom')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 1 file left unchanged, '
                '2 files failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)
            report.done(Path('f4'), changed=False)
            self.assertEqual(len(out_lines), 4)
            self.assertEqual(len(err_lines), 2)
            self.assertEqual(out_lines[-1], 'f4 already well formatted, good job.')
            self.assertEqual(
                unstyle(str(report)),
                '2 files reformatted, 2 files left unchanged, '
                '2 files failed to reformat.',
            )
            self.assertEqual(report.return_code, 123)

    def test_is_python36(self) -> None:
        node = black.lib2to3_parse("def f(*, arg): ...\n")
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg,): ...\n")
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse("def f(*, arg): f'string'\n")
        self.assertTrue(black.is_python36(node))
        source, expected = read_data('function')
        node = black.lib2to3_parse(source)
        self.assertTrue(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertTrue(black.is_python36(node))
        source, expected = read_data('expression')
        node = black.lib2to3_parse(source)
        self.assertFalse(black.is_python36(node))
        node = black.lib2to3_parse(expected)
        self.assertFalse(black.is_python36(node))
267
268
if __name__ == "__main__":
    # Allow running this module directly instead of through a test runner.
    unittest.main()