git.madduck.net Git - etc/vim.git/blob - tests/util.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix an example for 'Improved parentheses management' in the (future of the) Black...
[etc/vim.git] / tests / util.py
1 import os
2 import sys
3 import unittest
4 from contextlib import contextmanager
5 from dataclasses import replace
6 from functools import partial
7 from pathlib import Path
8 from typing import Any, Iterator, List, Optional, Tuple
9
10 import black
11 from black.debug import DebugVisitor
12 from black.mode import TargetVersion
13 from black.output import diff, err, out
14
# Default suffix that get_case_path() applies to case names.
PYTHON_SUFFIX = ".py"
# Case names ending in one of these are used verbatim by get_case_path(); any
# other name gets its suffix set via Path.with_suffix().
ALLOWED_SUFFIXES = (PYTHON_SUFFIX, ".pyi", ".out", ".diff", ".ipynb")

# Directory layout, anchored on the location of this module.
THIS_DIR = Path(__file__).parent
DATA_DIR = THIS_DIR.joinpath("data")
PROJECT_ROOT = THIS_DIR.parent
# Marker removed from every line read by read_data_from_file(). It is built by
# concatenation, presumably so this file never contains the full marker text
# itself. NOTE(review): confirm before folding it into a single literal.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
DETERMINISTIC_HEADER = "[Deterministic header]"

# Black target versions covering Python 3.6 through 3.9.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}

# Shared default mode plus pre-bound shortcuts. ``ff`` formats a file in place
# with fast=True; ``fs`` formats a string.
DEFAULT_MODE = black.Mode()
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
34
35
36 def _assert_format_equal(expected: str, actual: str) -> None:
37     if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
38         bdv: DebugVisitor[Any]
39         out("Expected tree:", fg="green")
40         try:
41             exp_node = black.lib2to3_parse(expected)
42             bdv = DebugVisitor()
43             list(bdv.visit(exp_node))
44         except Exception as ve:
45             err(str(ve))
46         out("Actual tree:", fg="red")
47         try:
48             exp_node = black.lib2to3_parse(actual)
49             bdv = DebugVisitor()
50             list(bdv.visit(exp_node))
51         except Exception as ve:
52             err(str(ve))
53
54     if actual != expected:
55         out(diff(expected, actual, "expected", "actual"))
56
57     assert actual == expected
58
59
class FormatFailure(Exception):
    """Raised by assert_format() when one of its additional formatting runs fails.

    The exception that caused the failure is chained as ``__cause__``.
    """
62
63
def assert_format(
    source: str,
    expected: str,
    mode: black.Mode = DEFAULT_MODE,
    *,
    fast: bool = False,
    minimum_version: Optional[Tuple[int, int]] = None,
) -> None:
    """Convenience function to check that Black formats as expected.

    ``source`` is formatted with ``mode`` and the result is compared against
    ``expected``. Two extra runs then make sure Black does not fail on the same
    input:

    * one with the preview flag flipped, and
    * one with preview disabled and ``line_length=1``.

    Only the first run compares output against ``expected``, because the extra
    modes may legitimately produce different output.

    Pass ``minimum_version`` when ``source`` uses syntax newer than the running
    interpreter supports. On older interpreters the AST-equivalence safety check
    is then skipped rather than crashing with a SyntaxError. Note that this is
    separate from the TargetVersion configuration in ``mode``.

    Raises:
        Any exception from the primary run (e.g. the AssertionError for a
        mismatched output) propagates unchanged.
        FormatFailure: if either extra run raises; the original exception is
        chained as its cause.
    """
    _assert_format_inner(
        source, expected, mode, fast=fast, minimum_version=minimum_version
    )

    flipped = "non-preview" if mode.preview else "preview"
    extra_runs = (
        # For both preview and non-preview tests, ensure that Black doesn't crash
        # on this code. "expected" is not passed because the precise output may
        # differ between modes.
        (
            replace(mode, preview=not mode.preview),
            f"Black crashed formatting this case in {flipped} mode.",
        ),
        # Similarly, setting line length to 1 is a good way to catch stability
        # bugs. This runs only in non-preview mode, because preview mode currently
        # has a lot of line-length-1 bugs.
        (
            replace(mode, preview=False, line_length=1),
            "Black crashed formatting this case with line-length set to 1.",
        ),
    )
    for extra_mode, failure_message in extra_runs:
        try:
            _assert_format_inner(
                source,
                None,
                extra_mode,
                fast=fast,
                minimum_version=minimum_version,
            )
        except Exception as e:
            raise FormatFailure(failure_message) from e
112
113
def _assert_format_inner(
    source: str,
    expected: Optional[str] = None,
    mode: black.Mode = DEFAULT_MODE,
    *,
    fast: bool = False,
    minimum_version: Optional[Tuple[int, int]] = None,
) -> None:
    """Format ``source`` once with ``mode`` and run Black's safety checks on it.

    If ``expected`` is given, the output must match it exactly. Unless ``fast`` is
    set, a changed output is also checked for stability. When no
    ``minimum_version`` is given, or the running interpreter is at least that
    version, the output is additionally checked for AST equivalence with
    ``source``.
    """
    formatted = black.format_str(source, mode=mode)
    if expected is not None:
        _assert_format_equal(expected, formatted)

    # Safety checks are only worth running when Black actually changed
    # something. Unchanged output is trivially equivalent to its source, so
    # skip the extra CPU work.
    if fast or formatted == source:
        return

    # The AST equivalence check relies on the built-in ast module, which cannot
    # parse syntax newer than the running interpreter, so it is gated on the
    # version.
    if minimum_version is None or sys.version_info >= minimum_version:
        black.assert_equivalent(source, formatted)
    black.assert_stable(source, formatted, mode=mode)
135
136
def dump_to_stderr(*output: str) -> str:
    """Join ``output`` with newlines and wrap the result in a leading and trailing newline.

    Despite its name, this writes nothing itself. It only returns the string.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
139
140
class BlackBaseTestCase(unittest.TestCase):
    """unittest base class that adds a Black-aware string comparison."""

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert ``actual == expected`` by delegating to ``_assert_format_equal``."""
        _assert_format_equal(expected, actual)
144
145
def get_base_dir(data: bool) -> Path:
    """Return DATA_DIR when ``data`` is true, otherwise PROJECT_ROOT."""
    if data:
        return DATA_DIR
    return PROJECT_ROOT
148
149
def all_data_cases(subdir_name: str, data: bool = True) -> List[str]:
    """Return the stems of every entry in ``subdir_name``, in sorted order.

    The base directory comes from get_base_dir(data). Each entry contributes its
    name without the final suffix (``Path.stem``).

    ``Path.iterdir`` yields entries in arbitrary, filesystem-dependent order, so
    the list is sorted. This keeps the result stable across runs and machines,
    which matters if, as seems likely, it is used to parametrize tests.

    Raises:
        AssertionError: if the directory does not exist.
    """
    cases_dir = get_base_dir(data) / subdir_name
    assert cases_dir.is_dir(), f"{cases_dir} is not a directory."
    return sorted(case_path.stem for case_path in cases_dir.iterdir())
154
155
def get_case_path(
    subdir_name: str, name: str, data: bool = True, suffix: str = PYTHON_SUFFIX
) -> Path:
    """Resolve the file for test case ``name`` inside ``subdir_name``.

    When ``name`` does not already end in one of ALLOWED_SUFFIXES, its suffix is
    set to ``suffix`` via ``Path.with_suffix``.

    Raises:
        AssertionError: if the resolved path is not an existing file.
    """
    candidate = get_base_dir(data) / subdir_name / name
    resolved = (
        candidate if name.endswith(ALLOWED_SUFFIXES) else candidate.with_suffix(suffix)
    )
    assert resolved.is_file(), f"{resolved} is not a file."
    return resolved
165
166
def read_data(subdir_name: str, name: str, data: bool = True) -> Tuple[str, str]:
    """Return the ``(input, expected_output)`` pair for a test case.

    read_data("<subdir>", "<test_name>") -> ("<input>", "<output>")

    The case file is located with get_case_path(), using the default ``.py``
    suffix unless ``name`` already carries an allowed suffix. It is then split
    into its two halves by read_data_from_file().
    """
    return read_data_from_file(get_case_path(subdir_name, name, data))
170
171
def read_data_from_file(file_name: Path) -> Tuple[str, str]:
    """Split a test-case file into ``(input, expected_output)``.

    Lines before a ``# output`` marker line form the input, and lines after it
    form the expected output. The marker itself is dropped. The EMPTY_LINE marker
    text is stripped from every line first. A file with input but no output
    section is treated as already formatted, so its output equals its input.
    Both halves are whitespace-stripped and end with exactly one newline.
    """
    source_lines: List[str] = []
    expected_lines: List[str] = []
    target = source_lines
    with open(file_name, "r", encoding="utf8") as handle:
        for raw_line in handle:
            cleaned = raw_line.replace(EMPTY_LINE, "")
            if cleaned.rstrip() == "# output":
                target = expected_lines
            else:
                target.append(cleaned)

    if source_lines and not expected_lines:
        # No output section: the case is expected to come out unchanged.
        expected_lines = list(source_lines)

    source_text = "".join(source_lines).strip() + "\n"
    expected_text = "".join(expected_lines).strip() + "\n"
    return source_text, expected_text
189
190
def read_jupyter_notebook(subdir_name: str, name: str, data: bool = True) -> str:
    """Return the raw text of the ``.ipynb`` test case ``name`` in ``subdir_name``."""
    notebook_path = get_case_path(subdir_name, name, data, suffix=".ipynb")
    return read_jupyter_notebook_from_file(notebook_path)
195
196
def read_jupyter_notebook_from_file(file_name: Path) -> str:
    """Return the contents of ``file_name`` decoded as UTF-8.

    The file is read as bytes, so line endings come back exactly as stored.
    """
    raw = Path(file_name).read_bytes()
    return raw.decode("utf-8")
201
202
@contextmanager
def change_directory(path: Path) -> Iterator[None]:
    """Make ``path`` the working directory for the duration of a ``with`` block.

    The previous working directory is restored on exit. This also happens when
    the body raises, or when changing into ``path`` itself fails.
    """
    saved_cwd = Path.cwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(saved_cwd)