All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you would read over the Git project's submission guidelines and adhere to
them, I would be especially grateful.
4 from contextlib import contextmanager
5 from dataclasses import replace
6 from functools import partial
7 from pathlib import Path
8 from typing import Any, Iterator, List, Optional, Tuple
11 from black.debug import DebugVisitor
12 from black.mode import TargetVersion
13 from black.output import diff, err, out
15 from . import conftest
# Suffixes that get_case_path() accepts as-is: a case name already ending in one
# of these is used unchanged, any other name gets the requested suffix applied.
18 ALLOWED_SUFFIXES = (PYTHON_SUFFIX, ".pyi", ".out", ".diff", ".ipynb")
# Directory that contains this module.
20 THIS_DIR = Path(__file__).parent
# "data" directory next to this module; get_base_dir(True) returns it.
21 DATA_DIR = THIS_DIR / "data"
# Parent of this module's directory; get_base_dir(False) returns it.
22 PROJECT_ROOT = THIS_DIR.parent
# Marker text that read_data_from_file() strips out of every line it reads.
23 EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
# NOTE(review): not referenced in the visible part of this file; presumably a
# placeholder for a varying header in expected output. Confirm against callers.
24 DETERMINISTIC_HEADER = "[Deterministic header]"
# Mode built from black.Mode()'s defaults; the default `mode` argument of the
# formatting helpers in this file.
33 DEFAULT_MODE = black.Mode()
# black.format_file_in_place pre-bound to DEFAULT_MODE with fast=True.
34 ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
# black.format_str pre-bound to DEFAULT_MODE.
35 fs = partial(black.format_str, mode=DEFAULT_MODE)
38 def _assert_format_equal(expected: str, actual: str) -> None:
39 if actual != expected and (conftest.PRINT_FULL_TREE or conftest.PRINT_TREE_DIFF):
40 bdv: DebugVisitor[Any]
42 expected_out: str = ""
43 if conftest.PRINT_FULL_TREE:
44 out("Expected tree:", fg="green")
46 exp_node = black.lib2to3_parse(expected)
47 bdv = DebugVisitor(print_output=conftest.PRINT_FULL_TREE)
48 list(bdv.visit(exp_node))
49 expected_out = "\n".join(bdv.list_output)
50 except Exception as ve:
52 if conftest.PRINT_FULL_TREE:
53 out("Actual tree:", fg="red")
55 exp_node = black.lib2to3_parse(actual)
56 bdv = DebugVisitor(print_output=conftest.PRINT_FULL_TREE)
57 list(bdv.visit(exp_node))
58 actual_out = "\n".join(bdv.list_output)
59 except Exception as ve:
61 if conftest.PRINT_TREE_DIFF:
64 diff(expected_out, actual_out, "expected tree", "actual tree")
65 or "Trees do not differ"
68 if actual != expected:
69 out(diff(expected, actual, "expected", "actual"))
71 assert actual == expected
class FormatFailure(Exception):
    """Wraps a failure that occurred while assert_format() ran in an extra mode."""
81 mode: black.Mode = DEFAULT_MODE,
84 minimum_version: Optional[Tuple[int, int]] = None,
86 """Convenience function to check that Black formats as expected.
88 You can pass @minimum_version if you're passing code with newer syntax to guard
89 safety guards so they don't just crash with a SyntaxError. Please note this is
90 separate from TargetVersion Mode configuration.
93 source, expected, mode, fast=fast, minimum_version=minimum_version
96 # For both preview and non-preview tests, ensure that Black doesn't crash on
97 # this code, but don't pass "expected" because the precise output may differ.
102 replace(mode, preview=not mode.preview),
104 minimum_version=minimum_version,
106 except Exception as e:
107 text = "non-preview" if mode.preview else "preview"
109 f"Black crashed formatting this case in {text} mode."
111 # Similarly, setting line length to 1 is a good way to catch
112 # stability bugs. But only in non-preview mode because preview mode
113 # currently has a lot of line length 1 bugs.
115 _assert_format_inner(
118 replace(mode, preview=False, line_length=1),
120 minimum_version=minimum_version,
122 except Exception as e:
124 "Black crashed formatting this case with line-length set to 1."
128 def _assert_format_inner(
130 expected: Optional[str] = None,
131 mode: black.Mode = DEFAULT_MODE,
134 minimum_version: Optional[Tuple[int, int]] = None,
136 actual = black.format_str(source, mode=mode)
137 if expected is not None:
138 _assert_format_equal(expected, actual)
139 # It's not useful to run safety checks if we're expecting no changes anyway. The
140 # assertion right above will raise if reality does actually make changes. This just
141 # avoids wasted CPU cycles.
142 if not fast and source != actual:
143 # Unfortunately the AST equivalence check relies on the built-in ast module
144 # being able to parse the code being formatted. This doesn't always work out
145 # when checking modern code on older versions.
146 if minimum_version is None or sys.version_info >= minimum_version:
147 black.assert_equivalent(source, actual)
148 black.assert_stable(source, actual, mode=mode)
def dump_to_stderr(*output: str) -> str:
    """Join the given strings into one newline-delimited block of text.

    Each argument is placed on its own line, and the result both starts and
    ends with a newline. Despite the name, nothing is written anywhere: the
    text is only returned to the caller.
    """
    joined = "\n".join(output)
    return f"\n{joined}\n"
class BlackBaseTestCase(unittest.TestCase):
    """unittest.TestCase base class offering a formatting-comparison assertion."""

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Check ``actual`` against ``expected`` by delegating to _assert_format_equal."""
        _assert_format_equal(expected=expected, actual=actual)
def get_base_dir(data: bool) -> Path:
    """Return DATA_DIR when ``data`` is truthy, and PROJECT_ROOT otherwise."""
    if data:
        return DATA_DIR
    return PROJECT_ROOT
def all_data_cases(subdir_name: str, data: bool = True) -> List[str]:
    """Return the stem of every entry in a cases directory, sorted by name.

    Args:
        subdir_name: Directory to list, relative to the base directory.
        data: Forwarded to get_base_dir() to choose the base directory.

    Returns:
        The ``Path.stem`` of each directory entry, in sorted order.

    Raises:
        AssertionError: If the resolved path is not an existing directory.
    """
    cases_dir = get_base_dir(data) / subdir_name
    assert cases_dir.is_dir(), f"{cases_dir} is not a directory."
    # Path.iterdir() yields entries in arbitrary, filesystem-dependent order;
    # sorting keeps the returned list (and anything parametrized from it) stable.
    return sorted(case_path.stem for case_path in cases_dir.iterdir())
171 subdir_name: str, name: str, data: bool = True, suffix: str = PYTHON_SUFFIX
173 """Get case path from name"""
174 case_path = get_base_dir(data) / subdir_name / name
175 if not name.endswith(ALLOWED_SUFFIXES):
176 case_path = case_path.with_suffix(suffix)
177 assert case_path.is_file(), f"{case_path} is not a file."
def read_data(subdir_name: str, name: str, data: bool = True) -> Tuple[str, str]:
    """Load the named case and return its ``(input, output)`` pair.

    The case path is resolved with get_case_path() and the file is then parsed
    by read_data_from_file().
    """
    case_path = get_case_path(subdir_name, name, data)
    return read_data_from_file(case_path)
186 def read_data_from_file(file_name: Path) -> Tuple[str, str]:
187 with open(file_name, "r", encoding="utf8") as test:
188 lines = test.readlines()
189 _input: List[str] = []
190 _output: List[str] = []
193 line = line.replace(EMPTY_LINE, "")
194 if line.rstrip() == "# output":
199 if _input and not _output:
200 # If there's no output marker, treat the entire file as already pre-formatted.
202 return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
def read_jupyter_notebook(subdir_name: str, name: str, data: bool = True) -> str:
    """Return the text content of the named notebook case.

    The path is resolved with get_case_path(), passing ``".ipynb"`` as the
    suffix, and the file is then read by read_jupyter_notebook_from_file().
    """
    notebook_path = get_case_path(subdir_name, name, data, suffix=".ipynb")
    return read_jupyter_notebook_from_file(notebook_path)
def read_jupyter_notebook_from_file(file_name: Path) -> str:
    """Return the entire content of ``file_name`` as text.

    The file is read as raw bytes and decoded with ``bytes.decode()``'s default
    codec (UTF-8), so line endings are returned exactly as stored on disk.
    """
    return Path(file_name).read_bytes().decode()
218 def change_directory(path: Path) -> Iterator[None]:
219 """Context manager to temporarily chdir to a different directory."""
220 previous_dir = os.getcwd()
225 os.chdir(previous_dir)