"""Shared helpers for Black's test suite.

Provides a ``unittest.TestCase`` base class whose format assertion dumps the
syntax trees of both sides on mismatch, plus helpers that read
``(input, expected_output)`` pairs from test data files.
"""

import os
import unittest
from pathlib import Path
from typing import List, Tuple, Any
from functools import partial

import black
from black.output import out, err
from black.debug import DebugVisitor

THIS_DIR = Path(__file__).parent
PROJECT_ROOT = THIS_DIR.parent
# Text stripped from every line read by read_data_from_file().
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
DETERMINISTIC_HEADER = "[Deterministic header]"
DEFAULT_MODE = black.Mode()
# Shorthands for formatting a file in place / a string with the default mode.
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)


def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines and wrap the result in leading/trailing newlines.

    Despite its name, this function writes nothing itself; it only returns the
    string (presumably used by callers as an assertion message — confirm).
    """
    return "\n" + "\n".join(output) + "\n"


def _print_tree(label: str, color: str, source: str) -> None:
    """Print *label* in *color*, then walk the lib2to3 tree of *source* with DebugVisitor.

    Any exception from parsing or visiting is reported via ``err`` instead of
    propagating: this is best-effort diagnostic output emitted just before the
    real assertion runs, so it must not mask that assertion.
    """
    out(label, fg=color)
    try:
        node = black.lib2to3_parse(source)
        # Explicit type parameter; DebugVisitor is generic.
        bdv: DebugVisitor[Any] = DebugVisitor()
        # visit() is a generator; consume it fully so the visitor runs over
        # the whole tree (presumably emitting the dump as a side effect).
        list(bdv.visit(node))
    except Exception as ve:
        err(str(ve))


class BlackBaseTestCase(unittest.TestCase):
    """TestCase base with unlimited diffs and tree dumps on format mismatches."""

    # Never truncate assertion-failure diffs.
    maxDiff = None
    # Raise unittest's private size cutoff for computing string diffs (1 MiB).
    _diffThreshold = 2 ** 20

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that *actual* equals *expected*.

        On mismatch, and unless the ``SKIP_AST_PRINT`` environment variable is
        set to a non-empty value, the expected and actual syntax trees are
        printed first to help diagnose the difference.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            _print_tree("Expected tree:", "green", expected)
            _print_tree("Actual tree:", "red", actual)
        self.assertMultiLineEqual(expected, actual)


def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    ``.py`` is appended to *name* unless it already ends in ``.py``, ``.pyi``,
    ``.out`` or ``.diff``. The file is looked up in ``THIS_DIR / "data"``, or
    in ``PROJECT_ROOT`` when *data* is false.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        name += ".py"
    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
    return read_data_from_file(base_dir / name)


def read_data_from_file(file_name: Path) -> Tuple[str, str]:
    """Split a UTF-8 test data file into ``(input, expected_output)``.

    Lines before the first line that reads ``# output`` (ignoring trailing
    whitespace) form the input; the remaining lines form the expected output.
    Every ``# output`` marker line is dropped, and ``EMPTY_LINE`` is removed
    from every line.

    If there is input but no output (no marker, or nothing after it), the
    input doubles as the expected output. Both halves are stripped of
    surrounding whitespace and end with exactly one newline.
    """
    with open(file_name, "r", encoding="utf8") as test:
        lines = test.readlines()
    _input: List[str] = []
    _output: List[str] = []
    result = _input
    for line in lines:
        line = line.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            # Everything after the marker belongs to the expected output.
            result = _output
            continue
        result.append(line)
    if _input and not _output:
        # If there's no output marker, treat the entire file as already pre-formatted.
        _output = _input[:]
    return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"