All patches and comments are welcome. Please squash your changes into logical
commits before using git-format-patch and git-send-email to send them to
patches@git.madduck.net.
If you read the Git project's submission guidelines and adhere to them,
I'd be especially grateful.
3 from contextlib import contextmanager
4 from pathlib import Path
5 from typing import List, Tuple, Iterator, Any
7 from functools import partial
# Directory containing this helper module, and the directory one level above it.
THIS_DIR = Path(__file__).parent
PROJECT_ROOT = THIS_DIR.parent

# Marker text that read_data_from_file strips from every line of a data file.
EMPTY_LINE = "# EMPTY LINE WITH WHITESPACE" + " (this comment will be removed)"
DETERMINISTIC_HEADER = "[Deterministic header]"

DEFAULT_MODE = black.Mode()
# Formatter shortcuts pre-bound to DEFAULT_MODE:
#   fs -> black.format_str
#   ff -> black.format_file_in_place, additionally with fast=True
fs = partial(black.format_str, mode=DEFAULT_MODE)
ff = partial(black.format_file_in_place, mode=DEFAULT_MODE, fast=True)
def dump_to_stderr(*output: str) -> str:
    """Join *output* with newlines and wrap the result in leading/trailing newlines.

    Despite its name, this function writes nothing; it only builds and returns
    the string (e.g. for use as an assertion message).
    """
    body = "\n".join(output)
    return f"\n{body}\n"
class BlackBaseTestCase(unittest.TestCase):
    """TestCase base class adding a formatting-aware string comparison."""

    # Overrides unittest's private threshold above which assertMultiLineEqual
    # stops computing a diff; presumably raised so large formatting diffs are
    # still shown — confirm against the unittest version in use.
    _diffThreshold = 2 ** 20

    def assertFormatEqual(self, expected: str, actual: str) -> None:
        """Assert that *actual* equals *expected*.

        On mismatch, and unless the SKIP_AST_PRINT environment variable is set,
        first dump a debug view of both strings' parse trees (via
        black.DebugVisitor) to help locate the difference, then fail through
        assertMultiLineEqual.
        """
        if actual != expected and not os.environ.get("SKIP_AST_PRINT"):
            self._print_tree("Expected tree:", "green", expected)
            self._print_tree("Actual tree:", "red", actual)
        self.assertMultiLineEqual(expected, actual)

    @staticmethod
    def _print_tree(header: str, color: str, source: str) -> None:
        """Print *header* in *color*, then walk *source*'s parse tree with black.DebugVisitor."""
        black.out(header, fg=color)
        try:
            node = black.lib2to3_parse(source)
            bdv: black.DebugVisitor[Any] = black.DebugVisitor()
            # visit() is a generator; exhaust it so every node is visited.
            list(bdv.visit(node))
        except Exception as ve:
            # Source that fails to parse is reported but must not pre-empt the
            # equality assertion that follows in assertFormatEqual.
            black.err(str(ve))
@contextmanager
def skip_if_exception(e: str) -> Iterator[None]:
    """Skip the running test if the wrapped code raises an exception named *e*.

    *e* is matched against the exception's class name only (not its module or
    base classes). Any other exception propagates unchanged.

    Raises:
        unittest.SkipTest: when the wrapped code raises an exception whose
            class name equals *e*.
    """
    try:
        yield
    except Exception as exc:
        if exc.__class__.__name__ != e:
            raise
        # unittest.skip(...) only builds a decorator; calling it here would
        # silently swallow the exception and let the test pass. Raising
        # SkipTest is what actually marks the test as skipped.
        raise unittest.SkipTest(
            f"Encountered expected exception {exc}, skipping"
        ) from exc
def read_data(name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('test_name') -> 'input', 'output'

    Load a test case and split it into its (input, expected output) sources.

    A *name* that does not end in .py, .pyi, .out or .diff gets ".py"
    appended. When *data* is true the file is resolved under THIS_DIR/"data";
    otherwise it is resolved under PROJECT_ROOT.
    """
    if not name.endswith((".py", ".pyi", ".out", ".diff")):
        # Bare test names (as in the example above) default to Python sources.
        name += ".py"
    base_dir = THIS_DIR / "data" if data else PROJECT_ROOT
    return read_data_from_file(base_dir / name)
def read_data_from_file(file_name: Path) -> Tuple[str, str]:
    """Split a test-data file into its (input, expected output) sources.

    Lines before a line reading ``# output`` form the input. Lines after it
    form the expected output. EMPTY_LINE markers are removed from every line.
    If the file contains no output section, the input doubles as the expected
    output. Both halves are stripped and returned with one trailing newline.
    """
    _input: List[str] = []
    _output: List[str] = []
    result = _input
    with open(file_name, "r", encoding="utf8") as test:
        for line in test:
            line = line.replace(EMPTY_LINE, "")
            if line.rstrip() == "# output":
                # Everything after the marker belongs to the expected output.
                result = _output
                continue
            result.append(line)
    if _input and not _output:
        # If there's no output marker, treat the entire file as already pre-formatted.
        _output = _input[:]
    return "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"