X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/e6cd10e7615f4df537e2eaefcf3904a4feecad1f..d3670d9c6568aef270ccda3006252ee3388cf587:/tests/util.py diff --git a/tests/util.py b/tests/util.py index da65ed0..1e86a3f 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,18 +1,20 @@ import os import unittest -from contextlib import contextmanager from pathlib import Path -from typing import List, Tuple, Iterator, Any -import black +from typing import List, Tuple, Any from functools import partial +import black +from black.output import out, err +from black.debug import DebugVisitor + THIS_DIR = Path(__file__).parent PROJECT_ROOT = THIS_DIR.parent EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)" DETERMINISTIC_HEADER = "[Deterministic header]" -DEFAULT_MODE = black.FileMode(experimental_string_processing=True) +DEFAULT_MODE = black.Mode() ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True) fs = partial(black.format_str, mode=DEFAULT_MODE) @@ -27,35 +29,24 @@ class BlackBaseTestCase(unittest.TestCase): def assertFormatEqual(self, expected: str, actual: str) -> None: if actual != expected and not os.environ.get("SKIP_AST_PRINT"): - bdv: black.DebugVisitor[Any] - black.out("Expected tree:", fg="green") + bdv: DebugVisitor[Any] + out("Expected tree:", fg="green") try: exp_node = black.lib2to3_parse(expected) - bdv = black.DebugVisitor() + bdv = DebugVisitor() list(bdv.visit(exp_node)) except Exception as ve: - black.err(str(ve)) - black.out("Actual tree:", fg="red") + err(str(ve)) + out("Actual tree:", fg="red") try: exp_node = black.lib2to3_parse(actual) - bdv = black.DebugVisitor() + bdv = DebugVisitor() list(bdv.visit(exp_node)) except Exception as ve: - black.err(str(ve)) + err(str(ve)) self.assertMultiLineEqual(expected, actual) -@contextmanager -def skip_if_exception(e: str) -> Iterator[None]: - try: - yield - except Exception as exc: - if exc.__class__.__name__ == e: - unittest.skip(f"Encountered expected exception {exc}, skipping") - else: - raise - - def read_data(name: str, data: bool = True) -> Tuple[str, str]: """read_data('test_name') -> 'input', 'output'""" if not name.endswith((".py", ".pyi", ".out", ".diff")):