X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/d0e06b53b09248be34c1d5c0fa8f050bff1d201c..79575f3376f043186d8b8c4885ef51c6b3c36246:/tests/util.py diff --git a/tests/util.py b/tests/util.py index 3670952..e83017f 100644 --- a/tests/util.py +++ b/tests/util.py @@ -1,10 +1,14 @@ import os import unittest from pathlib import Path -from typing import List, Tuple, Any -import black +from typing import Iterator, List, Tuple, Any +from contextlib import contextmanager from functools import partial +import black +from black.output import out, err +from black.debug import DebugVisitor + THIS_DIR = Path(__file__).parent PROJECT_ROOT = THIS_DIR.parent EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)" @@ -26,21 +30,21 @@ class BlackBaseTestCase(unittest.TestCase): def assertFormatEqual(self, expected: str, actual: str) -> None: if actual != expected and not os.environ.get("SKIP_AST_PRINT"): - bdv: black.DebugVisitor[Any] - black.out("Expected tree:", fg="green") + bdv: DebugVisitor[Any] + out("Expected tree:", fg="green") try: exp_node = black.lib2to3_parse(expected) - bdv = black.DebugVisitor() + bdv = DebugVisitor() list(bdv.visit(exp_node)) except Exception as ve: - black.err(str(ve)) - black.out("Actual tree:", fg="red") + err(str(ve)) + out("Actual tree:", fg="red") try: exp_node = black.lib2to3_parse(actual) - bdv = black.DebugVisitor() + bdv = DebugVisitor() list(bdv.visit(exp_node)) except Exception as ve: - black.err(str(ve)) + err(str(ve)) self.assertMultiLineEqual(expected, actual) @@ -69,3 +73,14 @@ def read_data_from_file(file_name: Path) -> Tuple[str, str]: # If there's no output marker, treat the entire file as already pre-formatted. _output = _input[:] return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n" + + +@contextmanager +def change_directory(path: Path) -> Iterator[None]: + """Context manager to temporarily chdir to a different directory.""" + previous_dir = os.getcwd() + try: + os.chdir(path) + yield + finally: + os.chdir(previous_dir)