X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/23fec8b0f73a142d1796c8ca958230ad52d83c24..4eb822f20cafcc2a62ac88f4f26298d732423bf1:/black.py diff --git a/black.py b/black.py index a48f647..b7cacf7 100644 --- a/black.py +++ b/black.py @@ -37,9 +37,11 @@ from typing import ( Union, cast, ) +from typing_extensions import Final +from mypy_extensions import mypyc_attr from appdirs import user_cache_dir -from attr import dataclass, evolve, Factory +from dataclasses import dataclass, field, replace import click import toml from typed_ast import ast3, ast27 @@ -185,7 +187,7 @@ VERSION_TO_FEATURES: Dict[TargetVersion, Set[Feature]] = { @dataclass class FileMode: - target_versions: Set[TargetVersion] = Factory(set) + target_versions: Set[TargetVersion] = field(default_factory=set) line_length: int = DEFAULT_LINE_LENGTH string_normalization: bool = True is_pyi: bool = False @@ -211,6 +213,23 @@ def supports_feature(target_versions: Set[TargetVersion], feature: Feature) -> b return all(feature in VERSION_TO_FEATURES[version] for version in target_versions) +def find_pyproject_toml(path_search_start: str) -> Optional[str]: + """Find the absolute filepath to a pyproject.toml if it exists""" + path_project_root = find_project_root(path_search_start) + path_pyproject_toml = path_project_root / "pyproject.toml" + return str(path_pyproject_toml) if path_pyproject_toml.is_file() else None + + +def parse_pyproject_toml(path_config: str) -> Dict[str, Any]: + """Parse a pyproject toml file, pulling out relevant parts for Black + + If parsing fails, will raise a toml.TomlDecodeError + """ + pyproject_toml = toml.load(path_config) + config = pyproject_toml.get("tool", {}).get("black", {}) + return {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} + + def read_pyproject_toml( ctx: click.Context, param: click.Parameter, value: Union[str, int, bool, None] ) -> Optional[str]: @@ -221,16 +240,12 @@ def read_pyproject_toml( """ assert not isinstance(value, (int, bool)), "Invalid parameter type passed" if not value: - root = find_project_root(ctx.params.get("src", ())) - path = root / "pyproject.toml" - if path.is_file(): - value = str(path) - else: + value = find_pyproject_toml(ctx.params.get("src", ())) + if value is None: return None try: - pyproject_toml = toml.load(value) - config = pyproject_toml.get("tool", {}).get("black", {}) + config = parse_pyproject_toml(value) except (toml.TomlDecodeError, OSError) as e: raise click.FileError( filename=value, hint=f"Error reading configuration file: {e}" @@ -241,12 +256,21 @@ def read_pyproject_toml( if ctx.default_map is None: ctx.default_map = {} - ctx.default_map.update( # type: ignore # bad types in .pyi - {k.replace("--", "").replace("-", "_"): v for k, v in config.items()} - ) + ctx.default_map.update(config) # type: ignore # bad types in .pyi return value +def target_version_option_callback( + c: click.Context, p: Union[click.Option, click.Parameter], v: Tuple[str, ...] +) -> List[TargetVersion]: + """Compute the target versions from a --target-version flag. + + This is its own function because mypy couldn't infer the type correctly + when it was a lambda, causing mypyc trouble. + """ + return [TargetVersion[val.upper()] for val in v] + + @click.command(context_settings=dict(help_option_names=["-h", "--help"])) @click.option("-c", "--code", type=str, help="Format the code passed in as a string.") @click.option( @@ -261,7 +285,7 @@ def read_pyproject_toml( "-t", "--target-version", type=click.Choice([v.name.lower() for v in TargetVersion]), - callback=lambda c, p, v: [TargetVersion[val.upper()] for val in v], + callback=target_version_option_callback, multiple=True, help=( "Python versions that should be supported by Black's output. [default: " @@ -388,7 +412,7 @@ def main( verbose: bool, include: str, exclude: str, - src: Tuple[str], + src: Tuple[str, ...], config: Optional[str], ) -> None: """The uncompromising code formatter.""" @@ -429,7 +453,7 @@ def main( except re.error: err(f"Invalid regular expression for exclude given: {exclude!r}") ctx.exit(2) - report = Report(check=check, quiet=quiet, verbose=verbose) + report = Report(check=check, diff=diff, quiet=quiet, verbose=verbose) root = find_project_root(src) sources: Set[Path] = set() path_empty(src, quiet, verbose, ctx) @@ -470,7 +494,9 @@ def main( ctx.exit(report.return_code) -def path_empty(src: Tuple[str], quiet: bool, verbose: bool, ctx: click.Context) -> None: +def path_empty( + src: Tuple[str, ...], quiet: bool, verbose: bool, ctx: click.Context +) -> None: """ Exit if there is no `src` provided for formatting """ @@ -585,7 +611,7 @@ async def schedule_formatting( ): src for src in sorted(sources) } - pending: Iterable[asyncio.Future] = tasks.keys() + pending: Iterable["asyncio.Future[bool]"] = tasks.keys() try: loop.add_signal_handler(signal.SIGINT, cancel, pending) loop.add_signal_handler(signal.SIGTERM, cancel, pending) @@ -629,7 +655,7 @@ def format_file_in_place( `mode` and `fast` options are passed to :func:`format_file_contents`. """ if src.suffix == ".pyi": - mode = evolve(mode, is_pyi=True) + mode = replace(mode, is_pyi=True) then = datetime.utcfromtimestamp(src.stat().st_mtime) with open(src, "rb") as buf: @@ -639,10 +665,10 @@ def format_file_in_place( except NothingChanged: return False - if write_back == write_back.YES: + if write_back == WriteBack.YES: with open(src, "w", encoding=encoding, newline=newline) as f: f.write(dst_contents) - elif write_back == write_back.DIFF: + elif write_back == WriteBack.DIFF: now = datetime.utcnow() src_name = f"{src}\t{then} +0000" dst_name = f"{src}\t{now} +0000" @@ -745,11 +771,9 @@ def format_str(src_contents: str, *, mode: FileMode) -> FileContent: if supports_feature(versions, feature) } for current_line in lines.visit(src_node): - for _ in range(after): - dst_contents.append(str(empty_line)) + dst_contents.append(str(empty_line) * after) before, after = elt.maybe_empty_lines(current_line) - for _ in range(before): - dst_contents.append(str(empty_line)) + dst_contents.append(str(empty_line) * before) for line in split_line( current_line, line_length=mode.line_length, features=split_line_features ): @@ -865,8 +889,16 @@ class Visitor(Generic[T]): if node.type < 256: name = token.tok_name[node.type] else: - name = type_repr(node.type) - yield from getattr(self, f"visit_{name}", self.visit_default)(node) + name = str(type_repr(node.type)) + # We explicitly branch on whether a visitor exists (instead of + # using self.visit_default as the default arg to getattr) in order + # to save needing to create a bound method object and so mypyc can + # generate a native call to visit_default. + visitf = getattr(self, f"visit_{name}", None) + if visitf: + yield from visitf(node) + else: + yield from self.visit_default(node) def visit_default(self, node: LN) -> Iterator[T]: """Default `visit_*()` implementation. Recurses to children of `node`.""" @@ -911,8 +943,8 @@ class DebugVisitor(Visitor[T]): list(v.visit(code)) -WHITESPACE = {token.DEDENT, token.INDENT, token.NEWLINE} -STATEMENT = { +WHITESPACE: Final = {token.DEDENT, token.INDENT, token.NEWLINE} +STATEMENT: Final = { syms.if_stmt, syms.while_stmt, syms.for_stmt, @@ -922,10 +954,10 @@ STATEMENT = { syms.funcdef, syms.classdef, } -STANDALONE_COMMENT = 153 +STANDALONE_COMMENT: Final = 153 token.tok_name[STANDALONE_COMMENT] = "STANDALONE_COMMENT" -LOGIC_OPERATORS = {"and", "or"} -COMPARATORS = { +LOGIC_OPERATORS: Final = {"and", "or"} +COMPARATORS: Final = { token.LESS, token.GREATER, token.EQEQUAL, @@ -933,7 +965,7 @@ COMPARATORS = { token.LESSEQUAL, token.GREATEREQUAL, } -MATH_OPERATORS = { +MATH_OPERATORS: Final = { token.VBAR, token.CIRCUMFLEX, token.AMPER, @@ -949,23 +981,23 @@ MATH_OPERATORS = { token.TILDE, token.DOUBLESTAR, } -STARS = {token.STAR, token.DOUBLESTAR} -VARARGS_SPECIALS = STARS | {token.SLASH} -VARARGS_PARENTS = { +STARS: Final = {token.STAR, token.DOUBLESTAR} +VARARGS_SPECIALS: Final = STARS | {token.SLASH} +VARARGS_PARENTS: Final = { syms.arglist, syms.argument, # double star in arglist syms.trailer, # single argument to call syms.typedargslist, syms.varargslist, # lambdas } -UNPACKING_PARENTS = { +UNPACKING_PARENTS: Final = { syms.atom, # single element of a list or set literal syms.dictsetmaker, syms.listmaker, syms.testlist_gexp, syms.testlist_star_expr, } -TEST_DESCENDANTS = { +TEST_DESCENDANTS: Final = { syms.test, syms.lambdef, syms.or_test, @@ -982,7 +1014,7 @@ TEST_DESCENDANTS = { syms.term, syms.power, } -ASSIGNMENTS = { +ASSIGNMENTS: Final = { "=", "+=", "-=", @@ -998,13 +1030,13 @@ ASSIGNMENTS = { "**=", "//=", } -COMPREHENSION_PRIORITY = 20 -COMMA_PRIORITY = 18 -TERNARY_PRIORITY = 16 -LOGIC_PRIORITY = 14 -STRING_PRIORITY = 12 -COMPARATOR_PRIORITY = 10 -MATH_PRIORITIES = { +COMPREHENSION_PRIORITY: Final = 20 +COMMA_PRIORITY: Final = 18 +TERNARY_PRIORITY: Final = 16 +LOGIC_PRIORITY: Final = 14 +STRING_PRIORITY: Final = 12 +COMPARATOR_PRIORITY: Final = 10 +MATH_PRIORITIES: Final = { token.VBAR: 9, token.CIRCUMFLEX: 8, token.AMPER: 7, @@ -1020,7 +1052,7 @@ MATH_PRIORITIES = { token.TILDE: 3, token.DOUBLESTAR: 2, } -DOT_PRIORITY = 1 +DOT_PRIORITY: Final = 1 @dataclass @@ -1028,11 +1060,11 @@ class BracketTracker: """Keeps track of brackets on a line.""" depth: int = 0 - bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = Factory(dict) - delimiters: Dict[LeafID, Priority] = Factory(dict) + bracket_match: Dict[Tuple[Depth, NodeType], Leaf] = field(default_factory=dict) + delimiters: Dict[LeafID, Priority] = field(default_factory=dict) previous: Optional[Leaf] = None - _for_loop_depths: List[int] = Factory(list) - _lambda_argument_depths: List[int] = Factory(list) + _for_loop_depths: List[int] = field(default_factory=list) + _lambda_argument_depths: List[int] = field(default_factory=list) def mark(self, leaf: Leaf) -> None: """Mark `leaf` with bracket-related metadata. Keep track of delimiters. @@ -1160,9 +1192,10 @@ class Line: """Holds leaves and comments. Can be printed with `str(line)`.""" depth: int = 0 - leaves: List[Leaf] = Factory(list) - comments: Dict[LeafID, List[Leaf]] = Factory(dict) # keys ordered like `leaves` - bracket_tracker: BracketTracker = Factory(BracketTracker) + leaves: List[Leaf] = field(default_factory=list) + # keys ordered like `leaves` + comments: Dict[LeafID, List[Leaf]] = field(default_factory=dict) + bracket_tracker: BracketTracker = field(default_factory=BracketTracker) inside_brackets: bool = False should_explode: bool = False @@ -1383,7 +1416,10 @@ class Line: for leaf_id, comments in self.comments.items(): for comment in comments: if is_type_comment(comment): - if leaf_id not in ignored_ids or comment_seen: + if comment_seen or ( + not is_type_comment(comment, " ignore") + and leaf_id not in ignored_ids + ): return True comment_seen = True @@ -1422,11 +1458,7 @@ class Line: return False def contains_multiline_strings(self) -> bool: - for leaf in self.leaves: - if is_multiline_string(leaf): - return True - - return False + return any(is_multiline_string(leaf) for leaf in self.leaves) def maybe_remove_trailing_comma(self, closing: Leaf) -> bool: """Remove trailing comma if there is one and it's safe.""" @@ -1565,7 +1597,7 @@ class EmptyLineTracker: is_pyi: bool = False previous_line: Optional[Line] = None previous_after: int = 0 - previous_defs: List[int] = Factory(list) + previous_defs: List[int] = field(default_factory=list) def maybe_empty_lines(self, current_line: Line) -> Tuple[int, int]: """Return the number of extra empty lines before and after the `current_line`. @@ -1679,7 +1711,7 @@ class LineGenerator(Visitor[Line]): is_pyi: bool = False normalize_strings: bool = True - current_line: Line = Factory(Line) + current_line: Line = field(default_factory=Line) remove_u_prefix: bool = False def line(self, indent: int = 0) -> Iterator[Line]: @@ -1728,13 +1760,13 @@ class LineGenerator(Visitor[Line]): self.current_line.append(node) yield from super().visit_default(node) - def visit_INDENT(self, node: Node) -> Iterator[Line]: + def visit_INDENT(self, node: Leaf) -> Iterator[Line]: """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. yield from self.line(+1) yield from self.visit_default(node) - def visit_DEDENT(self, node: Node) -> Iterator[Line]: + def visit_DEDENT(self, node: Leaf) -> Iterator[Line]: """Decrease indentation level, maybe yield a line.""" # The current line might still wait for trailing comments. At DEDENT time # there won't be any (they would be prefixes on the preceding NEWLINE). @@ -1844,7 +1876,7 @@ class LineGenerator(Visitor[Line]): node.insert_child(index, Node(syms.atom, [lpar, operand, rpar])) yield from self.visit_default(node) - def __attrs_post_init__(self) -> None: + def __post_init__(self) -> None: """You are in a twisty little maze of passages.""" v = self.visit_stmt Ø: Set[str] = set() @@ -2462,7 +2494,7 @@ def left_hand_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ body_leaves: List[Leaf] = [] head_leaves: List[Leaf] = [] current_leaves = head_leaves - matching_bracket = None + matching_bracket: Optional[Leaf] = None for leaf in line.leaves: if ( current_leaves is body_leaves @@ -2505,8 +2537,8 @@ def right_hand_split( body_leaves: List[Leaf] = [] head_leaves: List[Leaf] = [] current_leaves = tail_leaves - opening_bracket = None - closing_bracket = None + opening_bracket: Optional[Leaf] = None + closing_bracket: Optional[Leaf] = None for leaf in reversed(line.leaves): if current_leaves is body_leaves: if leaf is opening_bracket: @@ -2810,7 +2842,7 @@ def normalize_string_prefix(leaf: Leaf, remove_u_prefix: bool = False) -> None: match = re.match(r"^([furbFURB]*)(.*)$", leaf.value, re.DOTALL) assert match is not None, f"failed to match string {leaf.value!r}" orig_prefix = match.group(1) - new_prefix = orig_prefix.lower() + new_prefix = orig_prefix.replace("F", "f").replace("B", "b").replace("U", "u") if remove_u_prefix: new_prefix = new_prefix.replace("u", "") leaf.value = f"{new_prefix}{match.group(2)}" @@ -3027,7 +3059,7 @@ def convert_one_fmt_off_pair(node: Node) -> bool: # That happens when one of the `ignored_nodes` ended with a NEWLINE # leaf (possibly followed by a DEDENT). hidden_value = hidden_value[:-1] - first_idx = None + first_idx: Optional[int] = None for ignored in ignored_nodes: index = ignored.remove() if first_idx is None: @@ -3398,8 +3430,8 @@ def generate_trailers_to_omit(line: Line, line_length: int) -> Iterator[Set[Leaf yield omit length = 4 * line.depth - opening_bracket = None - closing_bracket = None + opening_bracket: Optional[Leaf] = None + closing_bracket: Optional[Leaf] = None inner_brackets: Set[LeafID] = set() for index, leaf, leaf_length in enumerate_with_length(line, reversed=True): length += leaf_length @@ -3570,7 +3602,7 @@ def find_project_root(srcs: Iterable[str]) -> Path: # Append a fake file so `parents` below returns `common_base_dir`, too. common_base /= "fake-file" for directory in common_base.parents: - if (directory / ".git").is_dir(): + if (directory / ".git").exists(): return directory if (directory / ".hg").is_dir(): @@ -3587,6 +3619,7 @@ class Report: """Provides a reformatting counter. Can be rendered with `str(report)`.""" check: bool = False + diff: bool = False quiet: bool = False verbose: bool = False change_count: int = 0 @@ -3596,7 +3629,7 @@ class Report: def done(self, src: Path, changed: Changed) -> None: """Increment the counter for successful reformatting. Write out a message.""" if changed is Changed.YES: - reformatted = "would reformat" if self.check else "reformatted" + reformatted = "would reformat" if self.check or self.diff else "reformatted" if self.verbose or not self.quiet: out(f"{reformatted} {src}") self.change_count += 1 @@ -3642,7 +3675,7 @@ class Report: Use `click.unstyle` to remove colors. """ - if self.check: + if self.check or self.diff: reformatted = "would be reformatted" unchanged = "would be left unchanged" failed = "would fail to reformat" @@ -3712,7 +3745,7 @@ def assert_equivalent(src: str, dst: str) -> None: yield f"{' ' * depth}{node.__class__.__name__}(" - for field in sorted(node._fields): + for field in sorted(node._fields): # noqa: F402 # TypeIgnore has only one field 'lineno' which breaks this comparison type_ignore_classes = (ast3.TypeIgnore, ast27.TypeIgnore) if sys.version_info >= (3, 8): @@ -3796,6 +3829,7 @@ def assert_stable(src: str, dst: str, mode: FileMode) -> None: ) from None +@mypyc_attr(patchable=True) def dump_to_file(*output: str) -> str: """Dump `output` to a temporary file. Return path to the file.""" with tempfile.NamedTemporaryFile( @@ -3828,7 +3862,7 @@ def diff(a: str, b: str, a_name: str, b_name: str) -> str: ) -def cancel(tasks: Iterable[asyncio.Task]) -> None: +def cancel(tasks: Iterable["asyncio.Task[Any]"]) -> None: """asyncio signal handler that cancels all `tasks` and reports to stderr.""" err("Aborted!") for task in tasks: