git.madduck.net Git - etc/vim.git/blob - tests/util.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

add support for printing the diff of AST trees when running tests (#3902)
[etc/vim.git] / tests / util.py
1 import os
2 import sys
3 import unittest
4 from contextlib import contextmanager
5 from dataclasses import replace
6 from functools import partial
7 from pathlib import Path
8 from typing import Any, Iterator, List, Optional, Tuple
9
10 import black
11 from black.debug import DebugVisitor
12 from black.mode import TargetVersion
13 from black.output import diff, err, out
14
15 from . import conftest
16
PYTHON_SUFFIX = ".py"
# Suffixes a test-case name may already carry; such names are used verbatim.
ALLOWED_SUFFIXES = (PYTHON_SUFFIX, ".pyi", ".out", ".diff", ".ipynb")

THIS_DIR = Path(__file__).parent
DATA_DIR = THIS_DIR.joinpath("data")
PROJECT_ROOT = THIS_DIR.parent
# Marker text removed from every line of a case file when it is read.
# NOTE(review): presumably assembled from two pieces so the complete marker
# never appears verbatim in this file — confirm before merging the literals.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
DETERMINISTIC_HEADER = "[Deterministic header]"

# Target versions from Python 3.6 up to and including 3.9.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}

# Default formatting mode, plus shorthands pre-bound to it:
# ff formats a file in place (fast=True), fs formats a string.
DEFAULT_MODE = black.Mode()
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
36
37
def _dump_tree(source: str, label: str, color: str) -> str:
    """Return the DebugVisitor dump of *source*'s parse tree, or "" on failure.

    When ``conftest.PRINT_FULL_TREE`` is set, *label* is printed in *color*
    and the visitor echoes the tree as it walks it.
    """
    if conftest.PRINT_FULL_TREE:
        out(label, fg=color)
    try:
        node = black.lib2to3_parse(source)
        visitor: DebugVisitor[Any] = DebugVisitor(
            print_output=conftest.PRINT_FULL_TREE
        )
        # visit() is consumed with list() so the whole tree is traversed and
        # list_output is fully populated.
        list(visitor.visit(node))
        return "\n".join(visitor.list_output)
    except Exception as exc:
        # Best effort: report the problem but still let the caller show the
        # plain-text diff, which is the primary diagnostic.
        err(str(exc))
        return ""


def _assert_format_equal(expected: str, actual: str) -> None:
    """Assert that *actual* equals *expected*, printing diagnostics on mismatch.

    On a mismatch with ``conftest.PRINT_FULL_TREE`` or
    ``conftest.PRINT_TREE_DIFF`` enabled, both strings are parsed with
    ``black.lib2to3_parse`` and dumped with DebugVisitor. The full trees are
    echoed when PRINT_FULL_TREE is set, and a diff of the two dumps is
    printed when PRINT_TREE_DIFF is set. A diff of the source text itself is
    always printed on mismatch.

    Raises:
        AssertionError: if *actual* differs from *expected*.
    """
    if actual != expected and (conftest.PRINT_FULL_TREE or conftest.PRINT_TREE_DIFF):
        expected_out = _dump_tree(expected, "Expected tree:", "green")
        actual_out = _dump_tree(actual, "Actual tree:", "red")
        if conftest.PRINT_TREE_DIFF:
            out("Tree Diff:")
            out(
                diff(expected_out, actual_out, "expected tree", "actual tree")
                or "Trees do not differ"
            )

    if actual != expected:
        out(diff(expected, actual, "expected", "actual"))

    assert actual == expected
72
73
class FormatFailure(Exception):
    """Raised by assert_format() when Black fails in one of its extra check modes.

    The exception that Black actually raised is chained as ``__cause__``.
    """
76
77
def assert_format(
    source: str,
    expected: str,
    mode: black.Mode = DEFAULT_MODE,
    *,
    fast: bool = False,
    minimum_version: Optional[Tuple[int, int]] = None,
) -> None:
    """Convenience function to check that Black formats as expected.

    *source* is formatted with *mode* and the result must equal *expected*.
    To catch crashes that only appear in other configurations, *source* is
    then formatted twice more without comparing the output: once with
    ``preview`` toggled, and once in non-preview mode with ``line_length=1``.
    Unless *fast* is set, every run whose output differs from *source* also
    runs Black's equivalence and stability checks.

    Pass *minimum_version* when *source* uses syntax newer than some supported
    interpreters understand. The AST equivalence check relies on the built-in
    ``ast`` module, so it is skipped on interpreters older than that version
    instead of failing with a SyntaxError. This is separate from the
    TargetVersion configuration in *mode*.

    Raises:
        AssertionError: the primary-mode output differs from *expected*. Any
            other exception from the primary run also propagates unchanged.
        FormatFailure: formatting failed in one of the extra modes; the
            original exception is chained as ``__cause__``.
    """
    _assert_format_inner(
        source, expected, mode, fast=fast, minimum_version=minimum_version
    )

    # For both preview and non-preview tests, ensure that Black doesn't crash on
    # this code, but don't pass "expected" because the precise output may differ.
    try:
        _assert_format_inner(
            source,
            None,
            replace(mode, preview=not mode.preview),
            fast=fast,
            minimum_version=minimum_version,
        )
    except Exception as e:
        text = "non-preview" if mode.preview else "preview"
        raise FormatFailure(
            f"Black crashed formatting this case in {text} mode."
        ) from e
    # Similarly, setting line length to 1 is a good way to catch
    # stability bugs. But only in non-preview mode because preview mode
    # currently has a lot of line length 1 bugs.
    try:
        _assert_format_inner(
            source,
            None,
            replace(mode, preview=False, line_length=1),
            fast=fast,
            minimum_version=minimum_version,
        )
    except Exception as e:
        raise FormatFailure(
            "Black crashed formatting this case with line-length set to 1."
        ) from e
126
127
def _assert_format_inner(
    source: str,
    expected: Optional[str] = None,
    mode: black.Mode = DEFAULT_MODE,
    *,
    fast: bool = False,
    minimum_version: Optional[Tuple[int, int]] = None,
) -> None:
    """Format *source* once under *mode* and run the requested checks.

    If *expected* is given, the formatted text must match it exactly. Unless
    *fast* is set, the safety checks (AST equivalence and stability) then run,
    but only when formatting changed *source*. AST equivalence is skipped on
    interpreters older than *minimum_version*.
    """
    formatted = black.format_str(source, mode=mode)
    if expected is not None:
        _assert_format_equal(expected, formatted)

    # No safety checks when asked to be fast, or when formatting was a no-op:
    # with nothing changed there is nothing to verify, and the comparison above
    # already fails whenever changes were not expected.
    if fast or formatted == source:
        return

    # The equivalence check parses with the built-in ast module, which cannot
    # handle syntax newer than the running interpreter.
    interpreter_can_parse = (
        minimum_version is None or sys.version_info >= minimum_version
    )
    if interpreter_can_parse:
        black.assert_equivalent(source, formatted)
    black.assert_stable(source, formatted, mode=mode)
149
150
def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines and pad the result with one newline per side.

    Despite the name, nothing is written anywhere; the string is only returned.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
153
154
class BlackBaseTestCase(unittest.TestCase):
    """unittest base class offering Black's diagnostic format assertion."""

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Fail unless *actual* equals *expected*; see _assert_format_equal."""
        return _assert_format_equal(expected, actual)
158
159
def get_base_dir(data: bool) -> Path:
    """Return the directory that test-case lookups are relative to.

    A true *data* selects the tests' ``data`` directory; otherwise the
    project root is used.
    """
    if data:
        return DATA_DIR
    return PROJECT_ROOT
162
163
def all_data_cases(subdir_name: str, data: bool = True) -> List[str]:
    """Return the sorted stems of every entry in the case directory *subdir_name*.

    *data* selects the base directory, as in get_base_dir(). ``Path.iterdir()``
    yields entries in arbitrary, filesystem-dependent order, so the stems are
    sorted to give callers a stable, reproducible order.

    Raises:
        AssertionError: if the case directory does not exist.
    """
    cases_dir = get_base_dir(data) / subdir_name
    assert cases_dir.is_dir(), f"{cases_dir} is not a directory."
    return sorted(case_path.stem for case_path in cases_dir.iterdir())
168
169
def get_case_path(
    subdir_name: str, name: str, data: bool = True, suffix: str = PYTHON_SUFFIX
) -> Path:
    """Resolve test case *name* inside *subdir_name* to an existing file.

    A *name* that already ends with one of ALLOWED_SUFFIXES is used verbatim.
    Any other name gets *suffix* applied via ``Path.with_suffix``.

    Raises:
        AssertionError: if the resolved path is not a regular file.
    """
    candidate = get_base_dir(data) / subdir_name / name
    resolved = (
        candidate
        if name.endswith(ALLOWED_SUFFIXES)
        else candidate.with_suffix(suffix)
    )
    assert resolved.is_file(), f"{resolved} is not a file."
    return resolved
179
180
def read_data(subdir_name: str, name: str, data: bool = True) -> Tuple[str, str]:
    """Return the ``(input, expected_output)`` pair for test case *name*.

    The case file is located with ``get_case_path(subdir_name, name, data)``,
    which applies the default ``.py`` suffix unless *name* already carries a
    recognised one. The file is then split by read_data_from_file().
    """
    return read_data_from_file(get_case_path(subdir_name, name, data))
184
185
def read_data_from_file(file_name: Path) -> Tuple[str, str]:
    """Split a case file into its ``(input, expected_output)`` sources.

    Every EMPTY_LINE marker is removed from each line first. Lines before a
    ``# output`` marker line form the input; lines after it form the expected
    output. If the file has input lines but no output lines, the input doubles
    as the expected output (the case is treated as already formatted). Both
    halves are stripped and end with exactly one newline.
    """
    with open(file_name, "r", encoding="utf8") as handle:
        raw_lines = handle.readlines()

    # Index 0 collects input lines, index 1 collects expected-output lines.
    sections: Tuple[List[str], List[str]] = ([], [])
    target = 0
    for raw in raw_lines:
        cleaned = raw.replace(EMPTY_LINE, "")
        if cleaned.rstrip() == "# output":
            target = 1
        else:
            sections[target].append(cleaned)

    source_lines, expected_lines = sections
    if source_lines and not expected_lines:
        expected_lines = list(source_lines)
    source_text = "".join(source_lines).strip() + "\n"
    expected_text = "".join(expected_lines).strip() + "\n"
    return source_text, expected_text
203
204
def read_jupyter_notebook(subdir_name: str, name: str, data: bool = True) -> str:
    """Return the raw text of the ``.ipynb`` test case *name* in *subdir_name*."""
    notebook_path = get_case_path(subdir_name, name, data, suffix=".ipynb")
    return read_jupyter_notebook_from_file(notebook_path)
209
210
def read_jupyter_notebook_from_file(file_name: Path) -> str:
    """Read *file_name* as bytes and decode it as UTF-8 (strict error handling)."""
    with open(file_name, mode="rb") as handle:
        return handle.read().decode("utf-8")
215
216
@contextmanager
def change_directory(path: Path) -> Iterator[None]:
    """Make *path* the working directory for the duration of the ``with`` body.

    The previous working directory is restored on exit, including when the
    body (or the directory change itself) raises.
    """
    saved_cwd = Path.cwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(saved_cwd)