+import argparse
+import functools
import os
+import shlex
import sys
import unittest
from contextlib import contextmanager
-from dataclasses import replace
+from dataclasses import dataclass, field, replace
from functools import partial
from pathlib import Path
from typing import Any, Iterator, List, Optional, Tuple
import black
+from black.const import DEFAULT_LINE_LENGTH
from black.debug import DebugVisitor
from black.mode import TargetVersion
from black.output import diff, err, out
# Shorthand: ``fs(source)`` is ``black.format_str(source, mode=DEFAULT_MODE)``.
fs = partial(black.format_str, mode=DEFAULT_MODE)
@dataclass
class TestCaseArgs:
    """Settings for a single formatting test case.

    Instances are produced by parse_mode() from a ``# flags:`` line, or take
    the defaults below when a case has no such line.

    Attributes:
        mode: The black.Mode used to format the case (default: ``black.Mode()``).
        fast: Value of the ``--fast`` flag.
        minimum_version: ``(major, minor)`` from ``--minimum-version``, or None
            when that flag was not given.
    """

    mode: black.Mode = field(default_factory=black.Mode)
    fast: bool = field(default=False)
    minimum_version: Optional[Tuple[int, int]] = field(default=None)


def _assert_format_equal(expected: str, actual: str) -> None:
if actual != expected and (conftest.PRINT_FULL_TREE or conftest.PRINT_TREE_DIFF):
bdv: DebugVisitor[Any]
return case_path
def read_data_with_mode(
    subdir_name: str, name: str, data: bool = True
) -> Tuple[TestCaseArgs, str, str]:
    """read_data_with_mode('subdir', 'test_name') -> TestCaseArgs(), 'input', 'output'

    Locate the test case with get_case_path() (``data`` is forwarded to it
    unchanged) and parse it with read_data_from_file(). The first element holds
    the arguments from the case's ``# flags:`` line, or defaults if it has none.
    """
    return read_data_from_file(get_case_path(subdir_name, name, data))


def read_data(subdir_name: str, name: str, data: bool = True) -> Tuple[str, str]:
    """read_data('subdir', 'test_name') -> 'input', 'output'

    Same as read_data_with_mode(), but discards the parsed ``# flags:`` arguments.
    """
    # Avoid naming a local ``input``: it would shadow the builtin.
    _, source, expected = read_data_with_mode(subdir_name, name, data)
    return source, expected


+def _parse_minimum_version(version: str) -> Tuple[int, int]:
+ major, minor = version.split(".")
+ return int(major), int(minor)
def _parse_target_version(value: str) -> TargetVersion:
    """argparse ``type=`` converter for ``--target-version`` (e.g. ``py38``).

    A bare ``TargetVersion[...]`` lookup raises KeyError for unknown names.
    argparse only turns ArgumentTypeError/TypeError/ValueError from a converter
    into a usage error, so re-raise as ArgumentTypeError listing valid names.
    """
    try:
        return TargetVersion[value.upper()]
    except KeyError:
        choices = ", ".join(member.name.lower() for member in TargetVersion)
        raise argparse.ArgumentTypeError(
            f"invalid target version {value!r} (choose from: {choices})"
        ) from None


@functools.lru_cache()
def get_flags_parser() -> argparse.ArgumentParser:
    """Return the parser for the ``# flags:`` line of a test-case file.

    The parser is built once and cached by ``lru_cache``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--target-version",
        action="append",
        type=_parse_target_version,
        # Must be a list: the "append" action copies the default and calls
        # .append() on the copy, which raises AttributeError for a tuple.
        # argparse appends to a copy, so this shared default is not mutated.
        default=[],
    )
    parser.add_argument("--line-length", default=DEFAULT_LINE_LENGTH, type=int)
    parser.add_argument(
        "--skip-string-normalization", default=False, action="store_true"
    )
    parser.add_argument("--pyi", default=False, action="store_true")
    parser.add_argument("--ipynb", default=False, action="store_true")
    parser.add_argument(
        "--skip-magic-trailing-comma", default=False, action="store_true"
    )
    parser.add_argument("--preview", default=False, action="store_true")
    parser.add_argument("--fast", default=False, action="store_true")
    parser.add_argument(
        "--minimum-version",
        type=_parse_minimum_version,
        default=None,
        help=(
            "Minimum version of Python where this test case is parseable. If this is"
            " set, the test case will be run twice: once with the specified"
            " --target-version, and once with --target-version set to exactly the"
            " specified version. This ensures that Black's autodetection of the target"
            " version works correctly."
        ),
    )
    return parser


def parse_mode(flags_line: str) -> TestCaseArgs:
    """Build TestCaseArgs from the text that follows ``# flags: `` in a case file.

    The text is tokenized with shell-style quoting (shlex) and parsed by the
    shared flags parser. The ``--skip-*`` switches are inverted into the
    corresponding positive black.Mode options.
    """
    ns = get_flags_parser().parse_args(shlex.split(flags_line))
    black_mode = black.Mode(
        target_versions=set(ns.target_version),
        line_length=ns.line_length,
        string_normalization=not ns.skip_string_normalization,
        is_pyi=ns.pyi,
        is_ipynb=ns.ipynb,
        magic_trailing_comma=not ns.skip_magic_trailing_comma,
        preview=ns.preview,
    )
    return TestCaseArgs(
        mode=black_mode,
        fast=ns.fast,
        minimum_version=ns.minimum_version,
    )


def read_data_from_file(file_name: Path) -> Tuple[TestCaseArgs, str, str]:
    """Parse a test-case file into ``(args, input, expected_output)``.

    - A ``# flags: ...`` line seen before any input line is parsed with
      parse_mode() and excluded from the input. Without one, ``args`` is a
      default TestCaseArgs().
    - Every occurrence of EMPTY_LINE is removed from the remaining lines.
    - A line reading ``# output`` (trailing whitespace ignored) starts the
      expected-output section and is itself excluded.
    - If that section is missing or empty, the expected output is the input.
    - Both returned strings are stripped and end with exactly one newline.
    """
    with open(file_name, "r", encoding="utf8") as test:
        lines = test.readlines()
    _input: List[str] = []
    _output: List[str] = []
    result = _input
    mode = TestCaseArgs()
    for line in lines:
        if not _input and line.startswith("# flags: "):
            mode = parse_mode(line[len("# flags: ") :])
            continue
        line = line.replace(EMPTY_LINE, "")
        if line.rstrip() == "# output":
            # Switch sections; the marker line itself belongs to neither.
            result = _output
            continue

        result.append(line)
    if _input and not _output:
        # If there's no output marker, treat the entire file as already pre-formatted.
        _output = _input[:]
    return mode, "".join(_input).strip() + "\n", "".join(_output).strip() + "\n"


def read_jupyter_notebook(subdir_name: str, name: str, data: bool = True) -> str: