import os
import unittest
-from contextlib import contextmanager
from pathlib import Path
-from typing import List, Tuple, Iterator, Any
-import black
+from typing import Iterator, List, Tuple, Any
+from contextlib import contextmanager
from functools import partial
+import black
+from black.output import out, err
+from black.debug import DebugVisitor
+
THIS_DIR = Path(__file__).parent
PROJECT_ROOT = THIS_DIR.parent
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
DETERMINISTIC_HEADER = "[Deterministic header]"
-DEFAULT_MODE = black.FileMode(experimental_string_processing=True)
+# Formatting mode shared by the ff/fs shortcuts below: black's Mode with no
+# constructor overrides.
+DEFAULT_MODE = black.Mode()
+# ff: black.format_file_in_place with fast=True; fs: black.format_str.
+# Both are pre-bound to DEFAULT_MODE.
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
fs = partial(black.format_str, mode=DEFAULT_MODE)
def assertFormatEqual(self, expected: str, actual: str) -> None:
+ """Assert that ``expected`` and ``actual`` are equal strings.
+
+ On a mismatch, and unless SKIP_AST_PRINT is set to a non-empty value,
+ first dump the parse trees of both strings (lib2to3_parse + DebugVisitor)
+ via black's out/err helpers; a failure while parsing or dumping is
+ printed with err() rather than raised.
+ """
if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
- bdv: black.DebugVisitor[Any]
- black.out("Expected tree:", fg="green")
+ bdv: DebugVisitor[Any]
+ out("Expected tree:", fg="green")
try:
exp_node = black.lib2to3_parse(expected)
- bdv = black.DebugVisitor()
+ bdv = DebugVisitor()
list(bdv.visit(exp_node))
+ # Tree dumping is diagnostic only: report the error and keep going.
except Exception as ve:
- black.err(str(ve))
- black.out("Actual tree:", fg="red")
+ err(str(ve))
+ out("Actual tree:", fg="red")
try:
+ # NOTE(review): despite its name, exp_node here holds the tree parsed
+ # from `actual` (name copied from the block above); act_node would read better.
exp_node = black.lib2to3_parse(actual)
- bdv = black.DebugVisitor()
+ bdv = DebugVisitor()
list(bdv.visit(exp_node))
except Exception as ve:
- black.err(str(ve))
+ err(str(ve))
self.assertMultiLineEqual(expected, actual)
-@contextmanager
-def skip_if_exception(e: str) -> Iterator[None]:
- try:
- yield
- except Exception as exc:
- if exc.__class__.__name__ == e:
- unittest.skip(f"Encountered expected exception {exc}, skipping")
- else:
- raise
-
-
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
"""read_data('test_name') -> 'input', 'output'"""
if not name.endswith((".py", ".pyi", ".out", ".diff")):
# If there's no output marker, treat the entire file as already pre-formatted.
_output = _input[:]
return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"
+
+
+@contextmanager
+def change_directory(path: Path) -> Iterator[None]:
+    """Run the ``with`` body with *path* as the current working directory.
+
+    The directory that was current on entry is switched back to on exit,
+    including when the body raises.
+    """
+    original_cwd = Path.cwd()
+    try:
+        os.chdir(path)
+        yield
+    finally:
+        os.chdir(original_cwd)