X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/407052724fa1c97ee8bcd4e96de650def00be03e..e76adbecb8c3b62631868332c3b632363c7c16b4:/tests/util.py diff --git a/tests/util.py b/tests/util.py index 9c3d3cb..e83017f 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,7 +1,13 @@ +import os import unittest -from contextlib import contextmanager from pathlib import Path -from typing import List, Tuple, Iterator +from typing import Iterator, List, Tuple, Any +from contextlib import contextmanager +from functools import partial + +import black +from black.output import out, err +from black.debug import DebugVisitor THIS_DIR = Path(__file__).parent PROJECT_ROOT = THIS_DIR.parent @@ -9,26 +15,52 @@ EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)" DETERMINISTIC_HEADER = "[Deterministic header]" -@contextmanager -def skip_if_exception(e: str) -> Iterator[None]: - try: - yield - except Exception as exc: - if exc.__class__.__name__ == e: - unittest.skip(f"Encountered expected exception {exc}, skipping") - else: - raise +DEFAULT_MODE = black.Mode() +ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True) +fs = partial(black.format_str, mode=DEFAULT_MODE) + + +def dump_to_stderr(*output: str) -> str: + return "\n" + "\n".join(output) + "\n" + + +class BlackBaseTestCase(unittest.TestCase): + maxDiff = None + _diffThreshold = 2 ** 20 + + def assertFormatEqual(self, expected: str, actual: str) -> None: + if actual != expected and not os.environ.get("SKIP_AST_PRINT"): + bdv: DebugVisitor[Any] + out("Expected tree:", fg="green") + try: + exp_node = black.lib2to3_parse(expected) + bdv = DebugVisitor() + list(bdv.visit(exp_node)) + except Exception as ve: + err(str(ve)) + out("Actual tree:", fg="red") + try: + exp_node = black.lib2to3_parse(actual) + bdv = DebugVisitor() + list(bdv.visit(exp_node)) + except Exception as ve: + err(str(ve)) + self.assertMultiLineEqual(expected, actual) def read_data(name: str, data: bool = True) -> Tuple[str, str]: """read_data('test_name') -> 'input', 'output'""" if not name.endswith((".py", ".pyi", ".out", ".diff")): name += ".py" - _input: List[str] = [] - _output: List[str] = [] base_dir = THIS_DIR / "data" if data else PROJECT_ROOT - with open(base_dir / name, "r", encoding="utf8") as test: + return read_data_from_file(base_dir / name) + + +def read_data_from_file(file_name: Path) -> Tuple[str, str]: + with open(file_name, "r", encoding="utf8") as test: lines = test.readlines() + _input: List[str] = [] + _output: List[str] = [] result = _input for line in lines: line = line.replace(EMPTY_LINE, "") @@ -41,3 +73,14 @@ def read_data(name: str, data: bool = True) -> Tuple[str, str]: # If there's no output marker, treat the entire file as already pre-formatted. _output = _input[:] return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n" + + +@contextmanager +def change_directory(path: Path) -> Iterator[None]: + """Context manager to temporarily chdir to a different directory.""" + previous_dir = os.getcwd() + try: + os.chdir(path) + yield + finally: + os.chdir(previous_dir)