X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/0a833b4b14953f98e81d632281a75318faa66170..53d9bace12b3aa230820c869a079020b4608c945:/tests/util.py

diff --git a/tests/util.py b/tests/util.py
index ad98669..1e86a3f 100644
--- a/tests/util.py
+++ b/tests/util.py
@@ -1,11 +1,13 @@
 import os
 import unittest
-from contextlib import contextmanager
 from pathlib import Path
-from typing import List, Tuple, Iterator, Any
-import black
+from typing import List, Tuple, Any
 from functools import partial
 
+import black
+from black.output import out, err
+from black.debug import DebugVisitor
+
 THIS_DIR = Path(__file__).parent
 PROJECT_ROOT = THIS_DIR.parent
 EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
@@ -27,35 +29,24 @@ class BlackBaseTestCase(unittest.TestCase):
 
     def assertFormatEqual(self, expected: str, actual: str) -> None:
         if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
-            bdv: black.DebugVisitor[Any]
-            black.out("Expected tree:", fg="green")
+            bdv: DebugVisitor[Any]
+            out("Expected tree:", fg="green")
             try:
                 exp_node = black.lib2to3_parse(expected)
-                bdv = black.DebugVisitor()
+                bdv = DebugVisitor()
                 list(bdv.visit(exp_node))
             except Exception as ve:
-                black.err(str(ve))
-            black.out("Actual tree:", fg="red")
+                err(str(ve))
+            out("Actual tree:", fg="red")
             try:
                 exp_node = black.lib2to3_parse(actual)
-                bdv = black.DebugVisitor()
+                bdv = DebugVisitor()
                 list(bdv.visit(exp_node))
             except Exception as ve:
-                black.err(str(ve))
+                err(str(ve))
         self.assertMultiLineEqual(expected, actual)
 
 
-@contextmanager
-def skip_if_exception(e: str) -> Iterator[None]:
-    try:
-        yield
-    except Exception as exc:
-        if exc.__class__.__name__ == e:
-            unittest.skip(f"Encountered expected exception {exc}, skipping")
-        else:
-            raise
-
-
 def read_data(name: str, data: bool = True) -> Tuple[str, str]:
     """read_data('test_name') -> 'input', 'output'"""
     if not name.endswith((".py", ".pyi", ".out", ".diff")):