git.madduck.net Git - etc/vim.git/blob - tests/util.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Strip trailing commas in subscripts with -C (#3209)
[etc/vim.git] / tests / util.py
1 import os
2 import sys
3 import unittest
4 from contextlib import contextmanager
5 from functools import partial
6 from pathlib import Path
7 from typing import Any, Iterator, List, Optional, Tuple
8
9 import black
10 from black.debug import DebugVisitor
11 from black.mode import TargetVersion
12 from black.output import diff, err, out
13
# Suffix appended to bare case names, and the suffixes that mark a case name
# as already complete (see get_case_path).
PYTHON_SUFFIX = ".py"
ALLOWED_SUFFIXES = (PYTHON_SUFFIX, ".pyi", ".out", ".diff", ".ipynb")

# Directory anchors: this tests directory, its "data" fixtures directory, and
# the directory one level up (used as the project root).
THIS_DIR = Path(__file__).parent
DATA_DIR = THIS_DIR.joinpath("data")
PROJECT_ROOT = THIS_DIR.parent

# Marker text that read_data_from_file deletes from every line it reads.
# NOTE(review): kept as a concatenation, presumably so the full marker never
# appears verbatim in this file — confirm before merging the literals.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
DETERMINISTIC_HEADER = "[Deterministic header]"
22
# Target versions grouped under the "Python 3.6+" label: 3.6 through 3.9.
PY36_VERSIONS = {
    TargetVersion.PY36,
    TargetVersion.PY37,
    TargetVersion.PY38,
    TargetVersion.PY39,
}

# Mode shared by the helpers below unless a caller supplies its own.
DEFAULT_MODE = black.Mode()
# Shortcuts pre-bound to DEFAULT_MODE: ``ff`` formats a file in place (with
# fast=True) and ``fs`` formats a string.
ff = partial(black.format_file_in_place, fast=True, mode=DEFAULT_MODE)
fs = partial(black.format_str, mode=DEFAULT_MODE)
33
34
def _dump_tree(title: str, color: str, source: str) -> None:
    """Print *title*, then dump the parse tree of *source* via DebugVisitor.

    Best effort: any exception while parsing or visiting is reported through
    ``err`` instead of being raised, so a tree that fails to parse never masks
    the assertion failure the caller is about to report.
    """
    out(title, fg=color)
    try:
        node = black.lib2to3_parse(source)
        visitor: DebugVisitor[Any] = DebugVisitor()
        # Drain the visit() iterator so the whole tree is walked.
        list(visitor.visit(node))
    except Exception as exc:
        err(str(exc))


def _assert_format_equal(expected: str, actual: str) -> None:
    """Assert that *actual* equals *expected*, printing diagnostics on mismatch.

    On a mismatch, the parse trees of both strings are dumped first (unless
    the ``SKIP_AST_PRINT`` environment variable is set to a non-empty value).
    Then a diff of expected vs. actual is printed, and the assertion fails.
    """
    if actual == expected:
        return

    if not os.environ.get("SKIP_AST_PRINT"):
        _dump_tree("Expected tree:", "green", expected)
        _dump_tree("Actual tree:", "red", actual)

    out(diff(expected, actual, "expected", "actual"))
    assert actual == expected
57
58
def assert_format(
    source: str,
    expected: str,
    mode: black.Mode = DEFAULT_MODE,
    *,
    fast: bool = False,
    minimum_version: Optional[Tuple[int, int]] = None,
) -> None:
    """Check that Black formats *source* into exactly *expected*.

    The formatted output is compared against *expected* first. Unless *fast*
    is set, or *source* already equals *expected*, the AST-equivalence and
    stability safety checks then run as well.

    Pass *minimum_version* when *source* uses syntax newer than the running
    interpreter can parse. The AST-equivalence check relies on the built-in
    ``ast`` module, so it is skipped on older interpreters instead of crashing
    with a SyntaxError. This is separate from the TargetVersion configuration
    in *mode*.
    """
    formatted = black.format_str(source, mode=mode)
    _assert_format_equal(expected, formatted)

    # When no change was expected, the comparison above already fails if Black
    # changed anything, so the safety checks would only waste CPU cycles.
    if fast or source == expected:
        return

    ast_parsable = minimum_version is None or sys.version_info >= minimum_version
    if ast_parsable:
        black.assert_equivalent(source, formatted)
    black.assert_stable(source, formatted, mode=mode)
85
86
def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines and wrap the result in leading/trailing newlines.

    Despite the name, nothing is written anywhere; the string is returned.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
89
90
class BlackBaseTestCase(unittest.TestCase):
    """TestCase base class exposing a formatting-aware equality assertion."""

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Delegate to _assert_format_equal (diagnostics + assertion)."""
        _assert_format_equal(expected=expected, actual=actual)
94
95
def get_base_dir(data: bool) -> Path:
    """Return DATA_DIR when *data* is truthy, otherwise PROJECT_ROOT."""
    if data:
        return DATA_DIR
    return PROJECT_ROOT
98
99
def all_data_cases(subdir_name: str, data: bool = True) -> List[str]:
    """Return the stems of every entry in a test-case directory, sorted.

    The directory is ``subdir_name`` under DATA_DIR when *data* is true, or
    under PROJECT_ROOT otherwise. ``Path.iterdir`` yields entries in arbitrary
    order, so the stems are sorted to give callers a deterministic order
    across runs and platforms.
    """
    cases_dir = get_base_dir(data) / subdir_name
    assert cases_dir.is_dir(), f"{cases_dir} is not a directory."
    return sorted(case_path.stem for case_path in cases_dir.iterdir())
104
105
def get_case_path(
    subdir_name: str, name: str, data: bool = True, suffix: str = PYTHON_SUFFIX
) -> Path:
    """Get case path from name.

    *name* is looked up in ``subdir_name`` under DATA_DIR (when *data* is
    true) or PROJECT_ROOT. If *name* does not already end with one of
    ALLOWED_SUFFIXES, *suffix* is appended to it.

    The suffix is appended rather than applied with ``Path.with_suffix``.
    with_suffix would replace an existing dotted part (``"foo.bar"`` ->
    ``foo.py``), so a stem returned by all_data_cases for ``foo.bar.py``
    would not map back to its own file.

    Asserts that the resulting path is an existing regular file.
    """
    case_path = get_base_dir(data) / subdir_name / name
    if not name.endswith(ALLOWED_SUFFIXES):
        case_path = case_path.with_name(case_path.name + suffix)
    assert case_path.is_file(), f"{case_path} is not a file."
    return case_path
115
116
def read_data(subdir_name: str, name: str, data: bool = True) -> Tuple[str, str]:
    """Return ``(input, expected_output)`` for test case *name* in *subdir_name*.

    The file is located with get_case_path (default ``.py`` suffix) and split
    into its two halves by read_data_from_file.
    """
    case_file = get_case_path(subdir_name, name, data)
    return read_data_from_file(case_file)
120
121
def read_data_from_file(file_name: Path) -> Tuple[str, str]:
    """Split a test-data file into ``(input, expected_output)`` strings.

    Lines before a ``# output`` marker line form the input; lines after it
    form the expected output. The EMPTY_LINE marker text is deleted from
    every line before it is classified.

    If the file has input lines but nothing after the marker (or no marker at
    all), the input is treated as already formatted and is returned as the
    expected output too. Both halves are stripped of surrounding whitespace
    and end with exactly one newline.
    """
    source_lines: List[str] = []
    expected_lines: List[str] = []
    target = source_lines
    with open(file_name, "r", encoding="utf8") as handle:
        for raw_line in handle:
            cleaned = raw_line.replace(EMPTY_LINE, "")
            if cleaned.rstrip() == "# output":
                # Everything after the marker belongs to the expected output.
                target = expected_lines
            else:
                target.append(cleaned)

    if source_lines and not expected_lines:
        expected_lines = list(source_lines)

    source_text = "".join(source_lines).strip() + "\n"
    expected_text = "".join(expected_lines).strip() + "\n"
    return source_text, expected_text
139
140
def read_jupyter_notebook(subdir_name: str, name: str, data: bool = True) -> str:
    """Return the decoded text of notebook case *name* in *subdir_name*.

    Bare names get the ``.ipynb`` suffix appended via get_case_path.
    """
    notebook_path = get_case_path(subdir_name, name, data, suffix=".ipynb")
    return read_jupyter_notebook_from_file(notebook_path)
145
146
def read_jupyter_notebook_from_file(file_name: Path) -> str:
    """Return the contents of *file_name* decoded as UTF-8.

    The file is read as raw bytes, so line endings are preserved exactly as
    stored (no universal-newline translation).
    """
    return Path(file_name).read_bytes().decode("utf-8")
151
152
@contextmanager
def change_directory(path: Path) -> Iterator[None]:
    """Make *path* the working directory for the duration of a ``with`` block.

    The previous working directory is restored on exit, including when the
    block raises.
    """
    saved_cwd = os.getcwd()
    try:
        os.chdir(path)
        yield None
    finally:
        os.chdir(saved_cwd)