X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/d0e06b53b09248be34c1d5c0fa8f050bff1d201c..bb588073ab286a9f1f8d839ab2cebe13011dd22c:/tests/util.py?ds=inline diff --git a/tests/util.py b/tests/util.py index 3670952..a31ae09 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,64 +1,273 @@ +import argparse +import functools import os +import shlex +import sys import unittest +from contextlib import contextmanager +from dataclasses import dataclass, field, replace +from functools import partial from pathlib import Path -from typing import List, Tuple, Any +from typing import Any, Iterator, List, Optional, Tuple + import black -from functools import partial +from black.const import DEFAULT_LINE_LENGTH +from black.debug import DebugVisitor +from black.mode import TargetVersion +from black.output import diff, err, out + +from . import conftest + +PYTHON_SUFFIX = ".py" +ALLOWED_SUFFIXES = (PYTHON_SUFFIX, ".pyi", ".out", ".diff", ".ipynb") THIS_DIR = Path(__file__).parent +DATA_DIR = THIS_DIR / "data" PROJECT_ROOT = THIS_DIR.parent EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)" DETERMINISTIC_HEADER = "[Deterministic header]" +PY36_VERSIONS = { + TargetVersion.PY36, + TargetVersion.PY37, + TargetVersion.PY38, + TargetVersion.PY39, +} DEFAULT_MODE = black.Mode() ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True) fs = partial(black.format_str, mode=DEFAULT_MODE) +@dataclass +class TestCaseArgs: + mode: black.Mode = field(default_factory=black.Mode) + fast: bool = False + minimum_version: Optional[Tuple[int, int]] = None + + +def _assert_format_equal(expected: str, actual: str) -> None: + if actual != expected and (conftest.PRINT_FULL_TREE or conftest.PRINT_TREE_DIFF): + bdv: DebugVisitor[Any] + actual_out: str = "" + expected_out: str = "" + if conftest.PRINT_FULL_TREE: + out("Expected tree:", fg="green") + try: + exp_node = black.lib2to3_parse(expected) + bdv = DebugVisitor(print_output=conftest.PRINT_FULL_TREE) + list(bdv.visit(exp_node)) + expected_out = "\n".join(bdv.list_output) + except Exception as ve: + err(str(ve)) + if conftest.PRINT_FULL_TREE: + out("Actual tree:", fg="red") + try: + exp_node = black.lib2to3_parse(actual) + bdv = DebugVisitor(print_output=conftest.PRINT_FULL_TREE) + list(bdv.visit(exp_node)) + actual_out = "\n".join(bdv.list_output) + except Exception as ve: + err(str(ve)) + if conftest.PRINT_TREE_DIFF: + out("Tree Diff:") + out( + diff(expected_out, actual_out, "expected tree", "actual tree") + or "Trees do not differ" + ) + + if actual != expected: + out(diff(expected, actual, "expected", "actual")) + + assert actual == expected + + +class FormatFailure(Exception): + """Used to wrap failures when assert_format() runs in an extra mode.""" + + +def assert_format( + source: str, + expected: str, + mode: black.Mode = DEFAULT_MODE, + *, + fast: bool = False, + minimum_version: Optional[Tuple[int, int]] = None, +) -> None: + """Convenience function to check that Black formats as expected. + + You can pass @minimum_version if you're passing code with newer syntax to guard + safety guards so they don't just crash with a SyntaxError. Please note this is + separate from TargetVerson Mode configuration. + """ + _assert_format_inner( + source, expected, mode, fast=fast, minimum_version=minimum_version + ) + + # For both preview and non-preview tests, ensure that Black doesn't crash on + # this code, but don't pass "expected" because the precise output may differ. + try: + _assert_format_inner( + source, + None, + replace(mode, preview=not mode.preview), + fast=fast, + minimum_version=minimum_version, + ) + except Exception as e: + text = "non-preview" if mode.preview else "preview" + raise FormatFailure( + f"Black crashed formatting this case in {text} mode." + ) from e + # Similarly, setting line length to 1 is a good way to catch + # stability bugs. But only in non-preview mode because preview mode + # currently has a lot of line length 1 bugs. + try: + _assert_format_inner( + source, + None, + replace(mode, preview=False, line_length=1), + fast=fast, + minimum_version=minimum_version, + ) + except Exception as e: + raise FormatFailure( + "Black crashed formatting this case with line-length set to 1." + ) from e + + +def _assert_format_inner( + source: str, + expected: Optional[str] = None, + mode: black.Mode = DEFAULT_MODE, + *, + fast: bool = False, + minimum_version: Optional[Tuple[int, int]] = None, +) -> None: + actual = black.format_str(source, mode=mode) + if expected is not None: + _assert_format_equal(expected, actual) + # It's not useful to run safety checks if we're expecting no changes anyway. The + # assertion right above will raise if reality does actually make changes. This just + # avoids wasted CPU cycles. + if not fast and source != actual: + # Unfortunately the AST equivalence check relies on the built-in ast module + # being able to parse the code being formatted. This doesn't always work out + # when checking modern code on older versions. + if minimum_version is None or sys.version_info >= minimum_version: + black.assert_equivalent(source, actual) + black.assert_stable(source, actual, mode=mode) + + def dump_to_stderr(*output: str) -> str: return "\n" + "\n".join(output) + "\n" class BlackBaseTestCase(unittest.TestCase): - maxDiff = None - _diffThreshold = 2 ** 20 - def assertFormatEqual(self, expected: str, actual: str) -> None: - if actual != expected and not os.environ.get("SKIP_AST_PRINT"): - bdv: black.DebugVisitor[Any] - black.out("Expected tree:", fg="green") - try: - exp_node = black.lib2to3_parse(expected) - bdv = black.DebugVisitor() - list(bdv.visit(exp_node)) - except Exception as ve: - black.err(str(ve)) - black.out("Actual tree:", fg="red") - try: - exp_node = black.lib2to3_parse(actual) - bdv = black.DebugVisitor() - list(bdv.visit(exp_node)) - except Exception as ve: - black.err(str(ve)) - self.assertMultiLineEqual(expected, actual) - - -def read_data(name: str, data: bool = True) -> Tuple[str, str]: + _assert_format_equal(expected, actual) + + +def get_base_dir(data: bool) -> Path: + return DATA_DIR if data else PROJECT_ROOT + + +def all_data_cases(subdir_name: str, data: bool = True) -> List[str]: + cases_dir = get_base_dir(data) / subdir_name + assert cases_dir.is_dir() + return [case_path.stem for case_path in cases_dir.iterdir()] + + +def get_case_path( + subdir_name: str, name: str, data: bool = True, suffix: str = PYTHON_SUFFIX +) -> Path: + """Get case path from name""" + case_path = get_base_dir(data) / subdir_name / name + if not name.endswith(ALLOWED_SUFFIXES): + case_path = case_path.with_suffix(suffix) + assert case_path.is_file(), f"{case_path} is not a file." + return case_path + + +def read_data_with_mode( + subdir_name: str, name: str, data: bool = True +) -> Tuple[TestCaseArgs, str, str]: + """read_data_with_mode('test_name') -> Mode(), 'input', 'output'""" + return read_data_from_file(get_case_path(subdir_name, name, data)) + + +def read_data(subdir_name: str, name: str, data: bool = True) -> Tuple[str, str]: """read_data('test_name') -> 'input', 'output'""" - if not name.endswith((".py", ".pyi", ".out", ".diff")): - name += ".py" - base_dir = THIS_DIR / "data" if data else PROJECT_ROOT - return read_data_from_file(base_dir / name) + _, input, output = read_data_with_mode(subdir_name, name, data) + return input, output + + +def _parse_minimum_version(version: str) -> Tuple[int, int]: + major, minor = version.split(".") + return int(major), int(minor) + +@functools.lru_cache() +def get_flags_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser() + parser.add_argument( + "--target-version", + action="append", + type=lambda val: TargetVersion[val.upper()], + default=(), + ) + parser.add_argument("--line-length", default=DEFAULT_LINE_LENGTH, type=int) + parser.add_argument( + "--skip-string-normalization", default=False, action="store_true" + ) + parser.add_argument("--pyi", default=False, action="store_true") + parser.add_argument("--ipynb", default=False, action="store_true") + parser.add_argument( + "--skip-magic-trailing-comma", default=False, action="store_true" + ) + parser.add_argument("--preview", default=False, action="store_true") + parser.add_argument("--fast", default=False, action="store_true") + parser.add_argument( + "--minimum-version", + type=_parse_minimum_version, + default=None, + help=( + "Minimum version of Python where this test case is parseable. If this is" + " set, the test case will be run twice: once with the specified" + " --target-version, and once with --target-version set to exactly the" + " specified version. This ensures that Black's autodetection of the target" + " version works correctly." + ), + ) + return parser -def read_data_from_file(file_name: Path) -> Tuple[str, str]: + +def parse_mode(flags_line: str) -> TestCaseArgs: + parser = get_flags_parser() + args = parser.parse_args(shlex.split(flags_line)) + mode = black.Mode( + target_versions=set(args.target_version), + line_length=args.line_length, + string_normalization=not args.skip_string_normalization, + is_pyi=args.pyi, + is_ipynb=args.ipynb, + magic_trailing_comma=not args.skip_magic_trailing_comma, + preview=args.preview, + ) + return TestCaseArgs(mode=mode, fast=args.fast, minimum_version=args.minimum_version) + + +def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]: with open(file_name, "r", encoding="utf8") as test: lines = test.readlines() _input: List[str] = [] _output: List[str] = [] result = _input + mode = TestCaseArgs() for line in lines: + if not _input and line.startswith("# flags: "): + mode = parse_mode(line[len("# flags: ") :]) + continue line = line.replace(EMPTY_LINE, "") if line.rstrip() == "# output": result = _output @@ -68,4 +277,27 @@ def read_data_from_file(file_name: Path) -> Tuple[str, str]: if _input and not _output: # If there's no output marker, treat the entire file as already pre-formatted. _output = _input[:] - return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n" + return mode, "".join(_input).strip() + "\n", "".join(_output).strip() + "\n" + + +def read_jupyter_notebook(subdir_name: str, name: str, data: bool = True) -> str: + return read_jupyter_notebook_from_file( + get_case_path(subdir_name, name, data, suffix=".ipynb") + ) + + +def read_jupyter_notebook_from_file(file_name: Path) -> str: + with open(file_name, mode="rb") as fd: + content_bytes = fd.read() + return content_bytes.decode() + + +@contextmanager +def change_directory(path: Path) -> Iterator[None]: + """Context manager to temporarily chdir to a different directory.""" + previous_dir = os.getcwd() + try: + os.chdir(path) + yield + finally: + os.chdir(previous_dir)