"""Shared helpers for the test suite: data-file loading and a base TestCase."""

import os
import unittest
from pathlib import Path
from typing import Iterator, List, Tuple, Any
from contextlib import contextmanager
from functools import partial

import black
from black.output import out, err
from black.debug import DebugVisitor

# Directory containing this file, and its parent (used as the project root).
THIS_DIR = Path(__file__).parent
PROJECT_ROOT = THIS_DIR.parent

# Marker text that read_data_from_file() strips out of every line it reads.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
DETERMINISTIC_HEADER = "[Deterministic header]"

DEFAULT_MODE = black.Mode()

# Shorthands with the default mode pre-bound.
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)


def dump_to_stderr(*output: str) -> str:
    """Return *output* joined by newlines, with a leading and trailing newline.

    Despite the name, nothing is written anywhere; the text is only returned.
    """
    return "\n" + "\n".join(output) + "\n"


def _print_debug_tree(source: str, label: str, color: str) -> None:
    """Print *label* in *color*, then the DebugVisitor dump of *source*.

    This is best-effort diagnostics: any exception raised while parsing or
    visiting is reported through ``err`` rather than propagated, so that the
    caller can still go on to run its real assertion.
    """
    out(label, fg=color)
    try:
        node = black.lib2to3_parse(source)
        visitor: DebugVisitor[Any] = DebugVisitor()
        # visit() is consumed with list() so the whole tree is walked.
        list(visitor.visit(node))
    except Exception as exc:
        err(str(exc))


class BlackBaseTestCase(unittest.TestCase):
    """Base TestCase adding a formatting-aware equality assertion."""

    # No limit on the length of diffs unittest prints on failure.
    maxDiff = None
    # Overrides unittest's private diff-size threshold with a larger value.
    _diffThreshold = 2 ** 20

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that *actual* equals *expected*, dumping debug trees first.

        When the strings differ and the ``SKIP_AST_PRINT`` environment variable
        is unset or empty, the trees of both strings are printed before the
        comparison. The comparison itself is ``assertMultiLineEqual``.

        Raises:
            AssertionError: if the two strings differ.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            _print_debug_tree(expected, "Expected tree:", "green")
            _print_debug_tree(actual, "Actual tree:", "red")
        self.assertMultiLineEqual(expected, actual)


def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    ``.py`` is appended to *name* unless it already ends with ``.py``,
    ``.pyi``, ``.out`` or ``.diff``. The file is looked up in the ``data``
    directory next to this module when *data* is true, otherwise in
    ``PROJECT_ROOT``.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = (THIS_DIR / "data") if data else PROJECT_ROOT
    return read_data_from_file(base_dir / name)


def read_data_from_file(file_name: Path) -> Tuple[str, str]:
    """Split the file at *file_name* into its input and expected-output parts.

    Lines before a line consisting of ``# output`` (ignoring trailing
    whitespace) form the input; lines after it form the output. The marker
    line itself is dropped, and every occurrence of ``EMPTY_LINE`` is removed
    from each line first. Both parts are stripped and terminated with a single
    newline.

    Returns:
        An ``(input, output)`` tuple of strings.
    """
    input_lines: List[str] = []
    output_lines: List[str] = []
    current = input_lines
    with open(file_name, "r", encoding="utf8") as test:
        for raw_line in test:
            line = raw_line.replace(EMPTY_LINE, "")
            if line.rstrip() == "# output":
                # Everything from here on belongs to the expected output.
                current = output_lines
                continue
            current.append(line)
    if input_lines and not output_lines:
        # If there's no output marker, treat the entire file as already
        # pre-formatted.
        output_lines = input_lines[:]
    return (
        "".join(input_lines).strip() + "\n",
        "".join(output_lines).strip() + "\n",
    )


@contextmanager
def change_directory(path: Path) -> Iterator[None]:
    """Context manager to temporarily chdir to a different directory.

    The working directory in effect on entry is restored on exit, including
    when the body (or the chdir itself) raises.
    """
    previous_dir = os.getcwd()
    try:
        os.chdir(path)
        yield
    finally:
        os.chdir(previous_dir)