+import os
import unittest
-from contextlib import contextmanager
from pathlib import Path
-from typing import List, Tuple, Iterator
+from typing import List, Tuple, Any
+from functools import partial
+
+import black
+from black.output import out, err
+from black.debug import DebugVisitor
DETERMINISTIC_HEADER: str = "[Deterministic header]"

# Directory that holds this module; read_data() looks for test cases
# under THIS_DIR / "data" by default.
THIS_DIR: Path = Path(__file__).parent
# One level above THIS_DIR; read_data(..., data=False) resolves names here.
PROJECT_ROOT: Path = THIS_DIR.parent
# Formatting mode shared by the shorthand helpers below (black.Mode() with
# its defaults).
DEFAULT_MODE = black.Mode()
# Shorthand for black.format_file_in_place with DEFAULT_MODE and fast=True
# (fast=True presumably skips black's post-format safety checks — confirm).
ff = partial(black.format_file_in_place, fast=True, mode=DEFAULT_MODE)
# Shorthand for black.format_str with DEFAULT_MODE.
fs = partial(black.format_str, mode=DEFAULT_MODE)
+
+
def dump_to_stderr(*output: str) -> str:
    """Join ``output`` with newlines and wrap the result in one newline on each side.

    Despite its name, nothing is written to stderr; the formatted string is
    only returned to the caller.
    """
    body = "\n".join(output)
    return f"\n{body}\n"
+
+
class BlackBaseTestCase(unittest.TestCase):
    """Base test case with a formatting-aware equality assertion.

    ``maxDiff = None`` lifts unittest's limit on the length of failure diffs.
    ``_diffThreshold`` is a private unittest attribute; raising it to 2**20
    presumably keeps diffs enabled for long strings (verify against the
    unittest version in use).
    """

    maxDiff = None
    _diffThreshold = 2 ** 20

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that ``actual`` equals ``expected``.

        On mismatch, and unless the ``SKIP_AST_PRINT`` environment variable
        is set to a non-empty value, the parse trees of both strings are
        dumped first to help locate the difference.

        Raises:
            AssertionError: if ``expected`` and ``actual`` differ.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            self._dump_tree("Expected tree:", "green", expected)
            self._dump_tree("Actual tree:", "red", actual)
        self.assertMultiLineEqual(expected, actual)

    @staticmethod
    def _dump_tree(header: str, color: str, source: str) -> None:
        """Print ``header``, then a DebugVisitor dump of ``source``'s parse tree."""
        out(header, fg=color)
        # Best effort: the source under test may not parse at all. Report the
        # parser error and fall through so the real assertion still runs.
        try:
            node = black.lib2to3_parse(source)
            visitor: DebugVisitor[Any] = DebugVisitor()
            # visit() yields lazily; list() forces the full traversal
            # (presumably DebugVisitor emits its output while traversing).
            list(visitor.visit(node))
        except Exception as exc:
            err(str(exc))
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Resolve ``name`` to a test-case file and return the (input, output) pair
    produced by ``read_data_from_file``.

    Args:
        name: File name. ``.py`` is appended unless it already ends with
            ``.py``, ``.pyi``, ``.out`` or ``.diff``.
        data: When true (the default), look under ``THIS_DIR / "data"``;
            otherwise resolve ``name`` relative to ``PROJECT_ROOT``.
    """
    has_known_suffix = name.endswith((".py", ".pyi", ".out", ".diff"))
    file_name = name if has_known_suffix else name + ".py"
    if data:
        base_dir = THIS_DIR / "data"
    else:
        base_dir = PROJECT_ROOT
    return read_data_from_file(base_dir / file_name)
+
+
+def read_data_from_file(file_name: Path) -> Tuple[str, str]:
+ with open(file_name, "r", encoding="utf8") as test:
lines = test.readlines()
+ _input: List[str] = []
+ _output: List[str] = []
result = _input
for line in lines:
line = line.replace(EMPTY_LINE, "")