All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you'd read over the Git project's submission guidelines and adhered to them,
I'd be especially grateful.
7 from contextlib import contextmanager
8 from dataclasses import dataclass, field, replace
9 from functools import partial
10 from pathlib import Path
11 from typing import Any, Iterator, List, Optional, Tuple
14 from black.const import DEFAULT_LINE_LENGTH
15 from black.debug import DebugVisitor
16 from black.mode import TargetVersion
17 from black.output import diff, err, out
19 from . import conftest
# Suffixes a case name may already carry; get_case_path() only appends its
# default suffix when the name ends with none of these.
ALLOWED_SUFFIXES = (PYTHON_SUFFIX, ".pyi", ".out", ".diff", ".ipynb")
# Filesystem anchors, all derived from the location of this module.
THIS_DIR = Path(__file__).parent
DATA_DIR = THIS_DIR / "data"
PROJECT_ROOT = THIS_DIR.parent

# Marker text that read_data_from_file() replaces with "" in every line it reads.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
DETERMINISTIC_HEADER = "[Deterministic header]"
# A Mode built with all defaults; used as the default `mode` argument of the
# assertion helpers in this module.
DEFAULT_MODE = black.Mode()

# Shorthands with the default mode pre-bound:
#   ff -> black.format_file_in_place(..., mode=DEFAULT_MODE, fast=True)
#   fs -> black.format_str(..., mode=DEFAULT_MODE)
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
44 mode: black.Mode = field(default_factory=black.Mode)
46 minimum_version: Optional[Tuple[int, int]] = None
49 def _assert_format_equal(expected: str, actual: str) -> None:
50 if actual != expected and (conftest.PRINT_FULL_TREE or conftest.PRINT_TREE_DIFF):
51 bdv: DebugVisitor[Any]
53 expected_out: str = ""
54 if conftest.PRINT_FULL_TREE:
55 out("Expected tree:", fg="green")
57 exp_node = black.lib2to3_parse(expected)
58 bdv = DebugVisitor(print_output=conftest.PRINT_FULL_TREE)
59 list(bdv.visit(exp_node))
60 expected_out = "\n".join(bdv.list_output)
61 except Exception as ve:
63 if conftest.PRINT_FULL_TREE:
64 out("Actual tree:", fg="red")
66 exp_node = black.lib2to3_parse(actual)
67 bdv = DebugVisitor(print_output=conftest.PRINT_FULL_TREE)
68 list(bdv.visit(exp_node))
69 actual_out = "\n".join(bdv.list_output)
70 except Exception as ve:
72 if conftest.PRINT_TREE_DIFF:
75 diff(expected_out, actual_out, "expected tree", "actual tree")
76 or "Trees do not differ"
79 if actual != expected:
80 out(diff(expected, actual, "expected", "actual"))
82 assert actual == expected
class FormatFailure(Exception):
    """Used to wrap failures when assert_format() runs in an extra mode."""
92 mode: black.Mode = DEFAULT_MODE,
95 minimum_version: Optional[Tuple[int, int]] = None,
97 """Convenience function to check that Black formats as expected.
99 You can pass @minimum_version if you're passing code with newer syntax to guard
100 safety guards so they don't just crash with a SyntaxError. Please note this is
101 separate from TargetVersion Mode configuration.
103 _assert_format_inner(
104 source, expected, mode, fast=fast, minimum_version=minimum_version
107 # For both preview and non-preview tests, ensure that Black doesn't crash on
108 # this code, but don't pass "expected" because the precise output may differ.
110 _assert_format_inner(
113 replace(mode, preview=not mode.preview),
115 minimum_version=minimum_version,
117 except Exception as e:
118 text = "non-preview" if mode.preview else "preview"
120 f"Black crashed formatting this case in {text} mode."
122 # Similarly, setting line length to 1 is a good way to catch
123 # stability bugs. But only in non-preview mode because preview mode
124 # currently has a lot of line length 1 bugs.
126 _assert_format_inner(
129 replace(mode, preview=False, line_length=1),
131 minimum_version=minimum_version,
133 except Exception as e:
135 "Black crashed formatting this case with line-length set to 1."
139 def _assert_format_inner(
141 expected: Optional[str] = None,
142 mode: black.Mode = DEFAULT_MODE,
145 minimum_version: Optional[Tuple[int, int]] = None,
147 actual = black.format_str(source, mode=mode)
148 if expected is not None:
149 _assert_format_equal(expected, actual)
150 # It's not useful to run safety checks if we're expecting no changes anyway. The
151 # assertion right above will raise if reality does actually make changes. This just
152 # avoids wasted CPU cycles.
153 if not fast and source != actual:
154 # Unfortunately the AST equivalence check relies on the built-in ast module
155 # being able to parse the code being formatted. This doesn't always work out
156 # when checking modern code on older versions.
157 if minimum_version is None or sys.version_info >= minimum_version:
158 black.assert_equivalent(source, actual)
159 black.assert_stable(source, actual, mode=mode)
def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines, framed by a leading and a trailing newline.

    Despite the name, nothing is written to stderr here; the assembled text is
    returned to the caller.
    """
    return "\n" + "\n".join(output) + "\n"
class BlackBaseTestCase(unittest.TestCase):
    """unittest.TestCase base class exposing the format comparison as a method."""

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Check *actual* against *expected* by delegating to _assert_format_equal()."""
        _assert_format_equal(expected, actual)
def get_base_dir(data: bool) -> Path:
    """Return DATA_DIR when *data* is true, otherwise PROJECT_ROOT."""
    return DATA_DIR if data else PROJECT_ROOT
def all_data_cases(subdir_name: str, data: bool = True) -> List[str]:
    """Return the stem of every entry directly inside the cases directory.

    The directory is *subdir_name* resolved against get_base_dir(data). An
    AssertionError is raised when that path is not a directory.
    """
    cases_dir = get_base_dir(data) / subdir_name
    # Same message style as the is_file() assertion in get_case_path().
    assert cases_dir.is_dir(), f"{cases_dir} is not a directory."
    return [case_path.stem for case_path in cases_dir.iterdir()]
182 subdir_name: str, name: str, data: bool = True, suffix: str = PYTHON_SUFFIX
184 """Get case path from name"""
185 case_path = get_base_dir(data) / subdir_name / name
186 if not name.endswith(ALLOWED_SUFFIXES):
187 case_path = case_path.with_suffix(suffix)
188 assert case_path.is_file(), f"{case_path} is not a file."
def read_data_with_mode(
    subdir_name: str, name: str, data: bool = True
) -> Tuple[TestCaseArgs, str, str]:
    """read_data_with_mode('test_name') -> TestCaseArgs(), 'input', 'output'

    Resolves the case file with get_case_path() and returns whatever
    read_data_from_file() yields for it.
    """
    return read_data_from_file(get_case_path(subdir_name, name, data))
199 def read_data(subdir_name: str, name: str, data: bool = True) -> Tuple[str, str]:
200 """read_data('test_name') -> 'input', 'output'"""
201 _, input, output = read_data_with_mode(subdir_name, name, data)
def _parse_minimum_version(version: str) -> Tuple[int, int]:
    """Parse a "MAJOR.MINOR" string into an ``(major, minor)`` tuple of ints.

    Raises ValueError when *version* does not consist of exactly two
    dot-separated integer components.
    """
    major, minor = version.split(".")
    return int(major), int(minor)
210 @functools.lru_cache()
211 def get_flags_parser() -> argparse.ArgumentParser:
212 parser = argparse.ArgumentParser()
216 type=lambda val: TargetVersion[val.upper()],
219 parser.add_argument("--line-length", default=DEFAULT_LINE_LENGTH, type=int)
221 "--skip-string-normalization", default=False, action="store_true"
223 parser.add_argument("--pyi", default=False, action="store_true")
224 parser.add_argument("--ipynb", default=False, action="store_true")
226 "--skip-magic-trailing-comma", default=False, action="store_true"
228 parser.add_argument("--preview", default=False, action="store_true")
229 parser.add_argument("--fast", default=False, action="store_true")
232 type=_parse_minimum_version,
235 "Minimum version of Python where this test case is parseable. If this is"
236 " set, the test case will be run twice: once with the specified"
237 " --target-version, and once with --target-version set to exactly the"
238 " specified version. This ensures that Black's autodetection of the target"
239 " version works correctly."
245 def parse_mode(flags_line: str) -> TestCaseArgs:
246 parser = get_flags_parser()
247 args = parser.parse_args(shlex.split(flags_line))
249 target_versions=set(args.target_version),
250 line_length=args.line_length,
251 string_normalization=not args.skip_string_normalization,
254 magic_trailing_comma=not args.skip_magic_trailing_comma,
255 preview=args.preview,
257 return TestCaseArgs(mode=mode, fast=args.fast, minimum_version=args.minimum_version)
260 def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]:
261 with open(file_name, "r", encoding="utf8") as test:
262 lines = test.readlines()
263 _input: List[str] = []
264 _output: List[str] = []
266 mode = TestCaseArgs()
268 if not _input and line.startswith("# flags: "):
269 mode = parse_mode(line[len("# flags: ") :])
271 line = line.replace(EMPTY_LINE, "")
272 if line.rstrip() == "# output":
277 if _input and not _output:
278 # If there's no output marker, treat the entire file as already pre-formatted.
280 return mode, "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
def read_jupyter_notebook(subdir_name: str, name: str, data: bool = True) -> str:
    """Return the text of the notebook case *name* found under *subdir_name*.

    The path is resolved by get_case_path() with ".ipynb" as the suffix to
    apply, then read by read_jupyter_notebook_from_file().
    """
    return read_jupyter_notebook_from_file(
        get_case_path(subdir_name, name, data, suffix=".ipynb")
    )
def read_jupyter_notebook_from_file(file_name: Path) -> str:
    """Return the contents of *file_name* decoded with the default UTF-8 codec.

    The file is opened in binary mode, so line endings reach the caller
    exactly as stored on disk (no newline translation).
    """
    with open(file_name, mode="rb") as fd:
        content_bytes = fd.read()
    return content_bytes.decode()
296 def change_directory(path: Path) -> Iterator[None]:
297 """Context manager to temporarily chdir to a different directory."""
298 previous_dir = os.getcwd()
303 os.chdir(previous_dir)