git.madduck.net Git - etc/vim.git/blob - tests/util.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix grammar for type alias support (#3949)
[etc/vim.git] / tests / util.py
1 import argparse
2 import functools
3 import os
4 import shlex
5 import sys
6 import unittest
7 from contextlib import contextmanager
8 from dataclasses import dataclass, field, replace
9 from functools import partial
10 from pathlib import Path
11 from typing import Any, Iterator, List, Optional, Tuple
12
13 import black
14 from black.const import DEFAULT_LINE_LENGTH
15 from black.debug import DebugVisitor
16 from black.mode import TargetVersion
17 from black.output import diff, err, out
18
19 from . import conftest
20
PYTHON_SUFFIX = ".py"
# Suffixes that get_case_path() keeps as-is instead of applying a default suffix.
ALLOWED_SUFFIXES = (PYTHON_SUFFIX, ".pyi", ".out", ".diff", ".ipynb")

# This file's directory, its "data" subdirectory, and the directory above it.
THIS_DIR = Path(__file__).parent
DATA_DIR = THIS_DIR.joinpath("data")
PROJECT_ROOT = THIS_DIR.parent
# Marker that read_data_from_file() strips from every line of a test data file.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
DETERMINISTIC_HEADER = "[Deterministic header]"

# TargetVersion.PY36 through TargetVersion.PY39, inclusive.
PY36_VERSIONS = {TargetVersion[f"PY3{minor}"] for minor in range(6, 10)}

DEFAULT_MODE = black.Mode()
# Shorthands: format a file in place (fast mode) / format a string, default mode.
ff = functools.partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = functools.partial(black.format_str, mode=DEFAULT_MODE)
40
41
@dataclass
class TestCaseArgs:
    """Per-test-case settings, normally built by parse_mode() from the
    ``# flags:`` line of a test data file (defaults when there is none)."""

    # Black configuration the case is formatted with.
    mode: black.Mode = field(default_factory=black.Mode)
    # Presumably forwarded as assert_format(fast=...); when true, the AST
    # equivalence and stability checks are skipped.
    fast: bool = field(default=False)
    # Oldest (major, minor) Python able to parse the case, or None for any.
    minimum_version: Optional[Tuple[int, int]] = field(default=None)
47
48
49 def _assert_format_equal(expected: str, actual: str) -> None:
50     if actual != expected and (conftest.PRINT_FULL_TREE or conftest.PRINT_TREE_DIFF):
51         bdv: DebugVisitor[Any]
52         actual_out: str = ""
53         expected_out: str = ""
54         if conftest.PRINT_FULL_TREE:
55             out("Expected tree:", fg="green")
56         try:
57             exp_node = black.lib2to3_parse(expected)
58             bdv = DebugVisitor(print_output=conftest.PRINT_FULL_TREE)
59             list(bdv.visit(exp_node))
60             expected_out = "\n".join(bdv.list_output)
61         except Exception as ve:
62             err(str(ve))
63         if conftest.PRINT_FULL_TREE:
64             out("Actual tree:", fg="red")
65         try:
66             exp_node = black.lib2to3_parse(actual)
67             bdv = DebugVisitor(print_output=conftest.PRINT_FULL_TREE)
68             list(bdv.visit(exp_node))
69             actual_out = "\n".join(bdv.list_output)
70         except Exception as ve:
71             err(str(ve))
72         if conftest.PRINT_TREE_DIFF:
73             out("Tree Diff:")
74             out(
75                 diff(expected_out, actual_out, "expected tree", "actual tree")
76                 or "Trees do not differ"
77             )
78
79     if actual != expected:
80         out(diff(expected, actual, "expected", "actual"))
81
82     assert actual == expected
83
84
class FormatFailure(Exception):
    """Signals that assert_format() crashed in one of its extra formatting passes.

    Those passes re-run Black with the preview flag flipped and with a line
    length of 1; the exception that escaped is chained as ``__cause__``.
    """
87
88
def assert_format(
    source: str,
    expected: str,
    mode: black.Mode = DEFAULT_MODE,
    *,
    fast: bool = False,
    minimum_version: Optional[Tuple[int, int]] = None,
) -> None:
    """Convenience function to check that Black formats as expected.

    *source* is formatted with *mode* and the result must equal *expected*;
    unless *fast* is set, the output is also run through Black's AST
    equivalence and stability checks. The same source is then formatted twice
    more without comparing output -- once with the preview flag flipped, and
    once with preview off and a line length of 1 -- to catch crashes there too.

    You can pass @minimum_version if you're passing code with newer syntax, so
    that the AST safety check is skipped (instead of crashing with a
    SyntaxError) on interpreters older than that (major, minor) version. Please
    note this is separate from the TargetVersion Mode configuration.

    Raises:
        AssertionError: if the primary output differs from *expected*. Errors
            from the safety checks of the primary pass propagate unchanged.
        FormatFailure: if anything raises during one of the two extra passes;
            the original exception is chained as ``__cause__``.
    """
    _assert_format_inner(
        source, expected, mode, fast=fast, minimum_version=minimum_version
    )

    # For both preview and non-preview tests, ensure that Black doesn't crash on
    # this code, but don't pass "expected" because the precise output may differ.
    try:
        _assert_format_inner(
            source,
            None,
            replace(mode, preview=not mode.preview),
            fast=fast,
            minimum_version=minimum_version,
        )
    except Exception as e:
        text = "non-preview" if mode.preview else "preview"
        raise FormatFailure(
            f"Black crashed formatting this case in {text} mode."
        ) from e
    # Similarly, setting line length to 1 is a good way to catch
    # stability bugs. But only in non-preview mode because preview mode
    # currently has a lot of line length 1 bugs.
    try:
        _assert_format_inner(
            source,
            None,
            replace(mode, preview=False, line_length=1),
            fast=fast,
            minimum_version=minimum_version,
        )
    except Exception as e:
        raise FormatFailure(
            "Black crashed formatting this case with line-length set to 1."
        ) from e
137
138
def _assert_format_inner(
    source: str,
    expected: Optional[str] = None,
    mode: black.Mode = DEFAULT_MODE,
    *,
    fast: bool = False,
    minimum_version: Optional[Tuple[int, int]] = None,
) -> None:
    """Format *source* once with *mode* and verify the result.

    If *expected* is given, the output must equal it. Unless *fast* is set or
    Black left the source unchanged, the output is then checked for AST
    equivalence (only when the running interpreter is at least
    *minimum_version*) and for stability.
    """
    formatted = black.format_str(source, mode=mode)
    if expected is not None:
        _assert_format_equal(expected, formatted)

    # Safety checks are pointless when they were disabled or nothing changed;
    # skipping them just saves CPU time.
    if fast or formatted == source:
        return

    # The equivalence check relies on the interpreter's own ast module, which
    # may be unable to parse syntax newer than the running Python.
    ast_can_parse = minimum_version is None or sys.version_info >= minimum_version
    if ast_can_parse:
        black.assert_equivalent(source, formatted)
    black.assert_stable(source, formatted, mode=mode)
160
161
def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines and wrap it in a leading and trailing newline.

    Despite the name, nothing is written anywhere; the string is only returned.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
164
165
class BlackBaseTestCase(unittest.TestCase):
    """unittest base class providing a formatting-aware equality assertion."""

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert *actual* == *expected*, printing diagnostics on mismatch."""
        _assert_format_equal(expected=expected, actual=actual)
169
170
def get_base_dir(data: bool) -> Path:
    """Return the test data directory if *data* is true, else the project root."""
    if data:
        return DATA_DIR
    return PROJECT_ROOT
173
174
def all_data_cases(subdir_name: str, data: bool = True) -> List[str]:
    """Return the stems of all entries in a test-case subdirectory, sorted.

    *subdir_name* is resolved under the data directory when *data* is true and
    under the project root otherwise; it must be an existing directory.
    """
    cases_dir = get_base_dir(data) / subdir_name
    assert cases_dir.is_dir(), f"{cases_dir} is not a directory."
    # Path.iterdir() yields entries in arbitrary, filesystem-dependent order;
    # sorting keeps the result (and any test parametrization built from it)
    # deterministic.
    return sorted(case_path.stem for case_path in cases_dir.iterdir())
179
180
def get_case_path(
    subdir_name: str, name: str, data: bool = True, suffix: str = PYTHON_SUFFIX
) -> Path:
    """Resolve test case *name* inside *subdir_name* to an existing file.

    A *name* that already ends with one of ALLOWED_SUFFIXES is used verbatim;
    otherwise *suffix* is applied via Path.with_suffix(). Fails an assertion
    if the resulting path is not a file.
    """
    candidate = get_base_dir(data) / subdir_name / name
    has_known_suffix = name.endswith(ALLOWED_SUFFIXES)
    case_path = candidate if has_known_suffix else candidate.with_suffix(suffix)
    assert case_path.is_file(), f"{case_path} is not a file."
    return case_path
190
191
def read_data_with_mode(
    subdir_name: str, name: str, data: bool = True
) -> Tuple[TestCaseArgs, str, str]:
    """Load a test case and return ``(args, input, expected_output)``.

    Example: ``read_data_with_mode("subdir", "test_name")`` ->
    ``(TestCaseArgs(...), "<input>", "<output>")``.

    *name* is resolved inside *subdir_name* by get_case_path() (under the data
    directory when *data* is true, the project root otherwise). *args* holds
    the settings from the case's ``# flags:`` line, or defaults if it has none.
    """
    return read_data_from_file(get_case_path(subdir_name, name, data))
197
198
def read_data(subdir_name: str, name: str, data: bool = True) -> Tuple[str, str]:
    """Return the ``(input, expected_output)`` source pair for a test case.

    Example: ``read_data("subdir", "test_name")`` -> ``("<input>", "<output>")``.
    Same as read_data_with_mode() but without the parsed TestCaseArgs.
    """
    # Named `source` rather than `input` to avoid shadowing the builtin.
    _, source, expected = read_data_with_mode(subdir_name, name, data)
    return source, expected
203
204
205 def _parse_minimum_version(version: str) -> Tuple[int, int]:
206     major, minor = version.split(".")
207     return int(major), int(minor)
208
209
def _parse_target_version(value: str) -> "TargetVersion":
    """Convert a ``--target-version`` value such as ``py38`` to a TargetVersion.

    The lookup is case-insensitive. Unknown names raise
    argparse.ArgumentTypeError so argparse reports a normal usage error; the
    KeyError raised by ``TargetVersion[...]`` is not one argparse handles.
    """
    try:
        return TargetVersion[value.upper()]
    except KeyError:
        raise argparse.ArgumentTypeError(
            f"invalid target version: {value!r}"
        ) from None


@functools.lru_cache()
def get_flags_parser() -> argparse.ArgumentParser:
    """Return the parser for the ``# flags:`` line of a test data file.

    The parser is built once and cached by lru_cache.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--target-version",
        action="append",
        type=_parse_target_version,
        # Must be a list, not a tuple: the "append" action copies the default
        # and calls .append() on the copy, so a tuple default raised
        # AttributeError as soon as the flag was used. Because a copy is
        # appended to, this shared default list itself is never mutated.
        default=[],
    )
    parser.add_argument("--line-length", default=DEFAULT_LINE_LENGTH, type=int)
    parser.add_argument(
        "--skip-string-normalization", default=False, action="store_true"
    )
    parser.add_argument("--pyi", default=False, action="store_true")
    parser.add_argument("--ipynb", default=False, action="store_true")
    parser.add_argument(
        "--skip-magic-trailing-comma", default=False, action="store_true"
    )
    parser.add_argument("--preview", default=False, action="store_true")
    parser.add_argument("--fast", default=False, action="store_true")
    parser.add_argument(
        "--minimum-version",
        type=_parse_minimum_version,
        default=None,
        help=(
            "Minimum version of Python where this test case is parseable. If this is"
            " set, the test case will be run twice: once with the specified"
            " --target-version, and once with --target-version set to exactly the"
            " specified version. This ensures that Black's autodetection of the target"
            " version works correctly."
        ),
    )
    return parser
243
244
def parse_mode(flags_line: str) -> TestCaseArgs:
    """Build TestCaseArgs from the text following ``# flags:`` in a data file.

    The text is split with shell-like quoting (shlex) and parsed with
    get_flags_parser(); the result is mapped onto a black.Mode.
    """
    parsed = get_flags_parser().parse_args(shlex.split(flags_line))
    black_mode = black.Mode(
        target_versions=set(parsed.target_version),
        line_length=parsed.line_length,
        string_normalization=not parsed.skip_string_normalization,
        is_pyi=parsed.pyi,
        is_ipynb=parsed.ipynb,
        magic_trailing_comma=not parsed.skip_magic_trailing_comma,
        preview=parsed.preview,
    )
    return TestCaseArgs(
        mode=black_mode, fast=parsed.fast, minimum_version=parsed.minimum_version
    )
258
259
def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]:
    """Split a test data file into ``(args, input, expected_output)``.

    A ``# flags: ...`` line seen before any input line is parsed with
    parse_mode(). Every other line has the EMPTY_LINE marker removed; lines go
    to the input until a ``# output`` line, and to the expected output after
    it. If no expected lines were collected, the input doubles as the expected
    output. Both texts are stripped and end with exactly one newline.
    """
    flags_prefix = "# flags: "
    with open(file_name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()

    source_lines: List[str] = []
    expected_lines: List[str] = []
    target = source_lines
    case_args = TestCaseArgs()
    for raw in raw_lines:
        if not source_lines and raw.startswith(flags_prefix):
            case_args = parse_mode(raw[len(flags_prefix) :])
            continue
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            target = expected_lines
        else:
            target.append(cleaned)

    # Nothing after a "# output" marker (or no marker at all): the case is
    # treated as already formatted.
    if source_lines and not expected_lines:
        expected_lines = list(source_lines)

    def _normalize(chunks: List[str]) -> str:
        return "".join(chunks).strip() + "\n"

    return case_args, _normalize(source_lines), _normalize(expected_lines)
281
282
def read_jupyter_notebook(subdir_name: str, name: str, data: bool = True) -> str:
    """Return the raw text of the ``.ipynb`` test case *name* in *subdir_name*."""
    notebook_path = get_case_path(subdir_name, name, data, suffix=".ipynb")
    return read_jupyter_notebook_from_file(notebook_path)
287
288
def read_jupyter_notebook_from_file(file_name: Path) -> str:
    """Return the contents of *file_name* decoded as UTF-8.

    The file is read in binary mode, so line endings come back exactly as
    stored instead of being translated.
    """
    with open(file_name, "rb") as handle:
        raw = handle.read()
    return raw.decode("utf-8")
293
294
@contextmanager
def change_directory(path: Path) -> Iterator[None]:
    """Temporarily make *path* the working directory.

    The directory that was current on entry is restored on exit, including
    when the body raises.
    """
    origin = Path.cwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(origin)