X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/75f99bded33abe962ca08bf16c77635ac9ca00a1..133af572072bf7bc92c23a609773c2ea66e483b7:/src/black/linegen.py diff --git a/src/black/linegen.py b/src/black/linegen.py index caffbab..5ef3bbd 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -1,38 +1,80 @@ """ Generating lines of code. """ -from functools import partial, wraps import sys +from dataclasses import replace +from enum import Enum, auto +from functools import partial, wraps from typing import Collection, Iterator, List, Optional, Set, Union, cast -from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT -from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS -from black.nodes import Visitor, syms, is_arith_like, ensure_visible +from black.brackets import ( + COMMA_PRIORITY, + DOT_PRIORITY, + get_leaves_inside_matching_brackets, + max_delimiter_priority_in_atom, +) +from black.comments import FMT_OFF, generate_comments, list_comments +from black.lines import ( + Line, + RHSResult, + append_leaves, + can_be_split, + can_omit_invisible_parens, + is_line_short_enough, + line_to_string, +) +from black.mode import Feature, Mode, Preview from black.nodes import ( + ASSIGNMENTS, + BRACKETS, + CLOSING_BRACKETS, + OPENING_BRACKETS, + RARROW, + STANDALONE_COMMENT, + STATEMENT, + WHITESPACE, + Visitor, + ensure_visible, + is_arith_like, + is_async_stmt_or_funcdef, + is_atom_with_invisible_parens, is_docstring, is_empty_tuple, - is_one_tuple, + is_lpar_token, + is_multiline_string, + is_name_token, is_one_sequence_between, + is_one_tuple, + is_rpar_token, + is_stub_body, + is_stub_suite, + is_tuple_containing_walrus, + is_type_ignore_comment_string, + is_vararg, + is_walrus_assignment, + is_yield, + syms, + wrap_in_parentheses, ) -from black.nodes import is_name_token, is_lpar_token, is_rpar_token -from black.nodes import is_walrus_assignment, is_yield, is_vararg, is_multiline_string -from black.nodes import 
is_stub_suite, is_stub_body, is_atom_with_invisible_parens -from black.nodes import wrap_in_parentheses -from black.brackets import max_delimiter_priority_in_atom -from black.brackets import DOT_PRIORITY, COMMA_PRIORITY -from black.lines import Line, line_to_string, is_line_short_enough -from black.lines import can_omit_invisible_parens, can_be_split, append_leaves -from black.comments import generate_comments, list_comments, FMT_OFF from black.numerics import normalize_numeric_literal -from black.strings import get_string_prefix, fix_docstring -from black.strings import normalize_string_prefix, normalize_string_quotes -from black.trans import Transformer, CannotTransform, StringMerger, StringSplitter -from black.trans import StringParenWrapper, StringParenStripper, hug_power_op -from black.mode import Mode, Feature, Preview - -from blib2to3.pytree import Node, Leaf +from black.strings import ( + fix_docstring, + get_string_prefix, + normalize_string_prefix, + normalize_string_quotes, + normalize_unicode_escape_sequences, +) +from black.trans import ( + CannotTransform, + StringMerger, + StringParenStripper, + StringParenWrapper, + StringSplitter, + Transformer, + hug_power_op, +) from blib2to3.pgen2 import token - +from blib2to3.pytree import Leaf, Node # types LeafID = int @@ -52,8 +94,9 @@ class LineGenerator(Visitor[Line]): in ways that will no longer stringify to valid Python code on the tree. """ - def __init__(self, mode: Mode) -> None: + def __init__(self, mode: Mode, features: Collection[Feature]) -> None: self.mode = mode + self.features = features self.current_line: Line self.__post_init__() @@ -69,6 +112,17 @@ class LineGenerator(Visitor[Line]): self.current_line.depth += indent return # Line is empty, don't emit. Creating a new one unnecessary. 
+ if ( + Preview.improved_async_statements_handling in self.mode + and len(self.current_line.leaves) == 1 + and is_async_stmt_or_funcdef(self.current_line.leaves[0]) + ): + # Special case for async def/for/with statements. `visit_async_stmt` + # adds an `ASYNC` leaf then visits the child def/for/with statement + # nodes. Line yields from those nodes shouldn't treat the former + # `ASYNC` leaf as a complete line. + return + complete_line = self.current_line self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent) yield complete_line @@ -77,7 +131,7 @@ class LineGenerator(Visitor[Line]): """Default `visit_*()` implementation. Recurses to children of `node`.""" if isinstance(node, Leaf): any_open_brackets = self.current_line.bracket_tracker.any_open_brackets() - for comment in generate_comments(node, preview=self.mode.preview): + for comment in generate_comments(node): if any_open_brackets: # any comment within brackets is subject to splitting self.current_line.append(comment) @@ -103,6 +157,22 @@ class LineGenerator(Visitor[Line]): self.current_line.append(node) yield from super().visit_default(node) + def visit_test(self, node: Node) -> Iterator[Line]: + """Visit an `x if y else z` test""" + + if Preview.parenthesize_conditional_expressions in self.mode: + already_parenthesized = ( + node.prev_sibling and node.prev_sibling.type == token.LPAR + ) + + if not already_parenthesized: + lpar = Leaf(token.LPAR, "") + rpar = Leaf(token.RPAR, "") + node.insert_child(0, lpar) + node.append_child(rpar) + + yield from self.visit_default(node) + def visit_INDENT(self, node: Leaf) -> Iterator[Line]: """Increase indentation level, maybe yield a line.""" # In blib2to3 INDENT never holds comments. @@ -137,26 +207,33 @@ class LineGenerator(Visitor[Line]): `parens` holds a set of string leaf values immediately after which invisible parens should be put. 
""" - normalize_invisible_parens(node, parens_after=parens, preview=self.mode.preview) + normalize_invisible_parens( + node, parens_after=parens, mode=self.mode, features=self.features + ) for child in node.children: if is_name_token(child) and child.value in keywords: yield from self.line() yield from self.visit(child) - def visit_funcdef(self, node: Node) -> Iterator[Line]: - """Visit function definition.""" - if Preview.annotation_parens not in self.mode: - yield from self.visit_stmt(node, keywords={"def"}, parens=set()) - else: - yield from self.line() + def visit_typeparams(self, node: Node) -> Iterator[Line]: + yield from self.visit_default(node) + node.children[0].prefix = "" + + def visit_typevartuple(self, node: Node) -> Iterator[Line]: + yield from self.visit_default(node) + node.children[1].prefix = "" - # Remove redundant brackets around return type annotation. - is_return_annotation = False - for child in node.children: - if child.type == token.RARROW: - is_return_annotation = True - elif is_return_annotation: + def visit_paramspec(self, node: Node) -> Iterator[Line]: + yield from self.visit_default(node) + node.children[1].prefix = "" + + def visit_dictsetmaker(self, node: Node) -> Iterator[Line]: + if Preview.wrap_long_dict_values_in_parens in self.mode: + for i, child in enumerate(node.children): + if i == 0: + continue + if node.children[i - 1].type == token.COLON: if child.type == syms.atom and child.children[0].type == token.LPAR: if maybe_make_parens_invisible_in_atom( child, @@ -166,14 +243,37 @@ class LineGenerator(Visitor[Line]): wrap_in_parentheses(node, child, visible=False) else: wrap_in_parentheses(node, child, visible=False) - is_return_annotation = False + yield from self.visit_default(node) - for child in node.children: - yield from self.visit(child) + def visit_funcdef(self, node: Node) -> Iterator[Line]: + """Visit function definition.""" + yield from self.line() + + # Remove redundant brackets around return type annotation. 
+ is_return_annotation = False + for child in node.children: + if child.type == token.RARROW: + is_return_annotation = True + elif is_return_annotation: + if child.type == syms.atom and child.children[0].type == token.LPAR: + if maybe_make_parens_invisible_in_atom( + child, + parent=node, + remove_brackets_around_comma=False, + ): + wrap_in_parentheses(node, child, visible=False) + else: + wrap_in_parentheses(node, child, visible=False) + is_return_annotation = False + + for child in node.children: + yield from self.visit(child) def visit_match_case(self, node: Node) -> Iterator[Line]: """Visit either a match or case statement.""" - normalize_invisible_parens(node, parens_after=set(), preview=self.mode.preview) + normalize_invisible_parens( + node, parens_after=set(), mode=self.mode, features=self.features + ) yield from self.line() for child in node.children: @@ -220,12 +320,17 @@ class LineGenerator(Visitor[Line]): for child in children: yield from self.visit(child) - if child.type == token.ASYNC: + if child.type == token.ASYNC or child.type == STANDALONE_COMMENT: + # STANDALONE_COMMENT happens when `# fmt: skip` is applied on the async + # line. 
break internal_stmt = next(children) - for child in internal_stmt.children: - yield from self.visit(child) + if Preview.improved_async_statements_handling in self.mode: + yield from self.visit(internal_stmt) + else: + for child in internal_stmt.children: + yield from self.visit(child) def visit_decorators(self, node: Node) -> Iterator[Line]: """Visit decorators.""" @@ -253,8 +358,7 @@ class LineGenerator(Visitor[Line]): ): wrap_in_parentheses(node, leaf) - if Preview.remove_redundant_parens in self.mode: - remove_await_parens(node) + remove_await_parens(node) yield from self.visit_default(node) @@ -290,10 +394,23 @@ class LineGenerator(Visitor[Line]): yield from self.visit_default(node) def visit_STRING(self, leaf: Leaf) -> Iterator[Line]: + if Preview.hex_codes_in_unicode_sequences in self.mode: + normalize_unicode_escape_sequences(leaf) + if is_docstring(leaf) and "\\\n" not in leaf.value: # We're ignoring docstrings with backslash newline escapes because changing # indentation of those changes the AST representation of the code. - docstring = normalize_string_prefix(leaf.value) + if self.mode.string_normalization: + docstring = normalize_string_prefix(leaf.value) + # visit_default() does handle string normalization for us, but + # since this method acts differently depending on quote style (ex. + # see padding logic below), there's a possibility for unstable + # formatting as visit_default() is called *after*. To avoid a + # situation where this function formats a docstring differently on + # the second pass, normalize it early. 
+ docstring = normalize_string_quotes(docstring) + else: + docstring = leaf.value prefix = get_string_prefix(docstring) docstring = docstring[len(prefix) :] # Remove the prefix quote_char = docstring[0] @@ -305,13 +422,14 @@ class LineGenerator(Visitor[Line]): quote_len = 1 if docstring[1] != quote_char else 3 docstring = docstring[quote_len:-quote_len] docstring_started_empty = not docstring + indent = " " * 4 * self.current_line.depth if is_multiline_string(leaf): - indent = " " * 4 * self.current_line.depth docstring = fix_docstring(docstring, indent) else: docstring = docstring.strip() + has_trailing_backslash = False if docstring: # Add some padding if the docstring starts / ends with a quote mark. if docstring[0] == quote_char: @@ -324,12 +442,37 @@ class LineGenerator(Visitor[Line]): # Odd number of tailing backslashes, add some padding to # avoid escaping the closing string quote. docstring += " " + has_trailing_backslash = True elif not docstring_started_empty: docstring = " " # We could enforce triple quotes at this point. quote = quote_char * quote_len - leaf.value = prefix + quote + docstring + quote + + # It's invalid to put closing single-character quotes on a new line. + if self.mode and quote_len == 3: + # We need to find the length of the last line of the docstring + # to find if we can add the closing quotes to the line without + # exceeding the maximum line length. + # If docstring is one line, we don't put the closing quotes on a + # separate line because it looks ugly (#3320). 
+ lines = docstring.splitlines() + last_line_length = len(lines[-1]) if docstring else 0 + + # If adding closing quotes would cause the last line to exceed + # the maximum line length then put a line break before the + # closing quotes + if ( + len(lines) > 1 + and last_line_length + quote_len > self.mode.line_length + and len(indent) + quote_len <= self.mode.line_length + and not has_trailing_backslash + ): + leaf.value = prefix + quote + docstring + "\n" + indent + quote + else: + leaf.value = prefix + quote + docstring + quote + else: + leaf.value = prefix + quote + docstring + quote yield from self.visit_default(leaf) @@ -348,14 +491,8 @@ class LineGenerator(Visitor[Line]): self.visit_try_stmt = partial( v, keywords={"try", "except", "else", "finally"}, parens=Ø ) - if self.mode.preview: - self.visit_except_clause = partial( - v, keywords={"except"}, parens={"except"} - ) - self.visit_with_stmt = partial(v, keywords={"with"}, parens={"with"}) - else: - self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø) - self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø) + self.visit_except_clause = partial(v, keywords={"except"}, parens={"except"}) + self.visit_with_stmt = partial(v, keywords={"with"}, parens={"with"}) self.visit_classdef = partial(v, keywords={"class"}, parens=Ø) self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS) self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"}) @@ -397,7 +534,7 @@ def transform_line( and not line.should_split_rhs and not line.magic_trailing_comma and ( - is_line_short_enough(line, line_length=mode.line_length, line_str=line_str) + is_line_short_enough(line, mode=mode, line_str=line_str) or line.contains_unsplittable_type_ignore() ) and not (line.inside_brackets and line.contains_standalone_comments()) @@ -412,7 +549,7 @@ def transform_line( else: def _rhs( - self: object, line: Line, features: Collection[Feature] + self: object, line: Line, features: Collection[Feature], 
mode: Mode ) -> Iterator[Line]: """Wraps calls to `right_hand_split`. @@ -421,14 +558,12 @@ def transform_line( bracket pair instead. """ for omit in generate_trailers_to_omit(line, mode.line_length): - lines = list( - right_hand_split(line, mode.line_length, features, omit=omit) - ) + lines = list(right_hand_split(line, mode, features, omit=omit)) # Note: this check is only able to figure out if the first line of the # *current* transformation fits in the line length. This is true only # for simple cases. All others require running more transforms via # `transform_line()`. This check doesn't know if those would succeed. - if is_line_short_enough(lines[0], line_length=mode.line_length): + if is_line_short_enough(lines[0], mode=mode): yield from lines return @@ -436,9 +571,7 @@ def transform_line( # This mostly happens to multiline strings that are by definition # reported as not fitting a single line, as well as lines that contain # trailing commas (those have to be exploded). - yield from right_hand_split( - line, line_length=mode.line_length, features=features - ) + yield from right_hand_split(line, mode, features=features) # HACK: nested functions (like _rhs) compiled by mypyc don't retain their # __name__ attribute which is needed in `run_transformer` further down. @@ -490,7 +623,15 @@ def transform_line( yield line -def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]: +class _BracketSplitComponent(Enum): + head = auto() + body = auto() + tail = auto() + + +def left_hand_split( + line: Line, _features: Collection[Feature], mode: Mode +) -> Iterator[Line]: """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. 
@@ -520,9 +661,15 @@ def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator if not matching_bracket: raise CannotSplit("No brackets found") - head = bracket_split_build_line(head_leaves, line, matching_bracket) - body = bracket_split_build_line(body_leaves, line, matching_bracket, is_body=True) - tail = bracket_split_build_line(tail_leaves, line, matching_bracket) + head = bracket_split_build_line( + head_leaves, line, matching_bracket, component=_BracketSplitComponent.head + ) + body = bracket_split_build_line( + body_leaves, line, matching_bracket, component=_BracketSplitComponent.body + ) + tail = bracket_split_build_line( + tail_leaves, line, matching_bracket, component=_BracketSplitComponent.tail + ) bracket_split_succeeded_or_raise(head, body, tail) for result in (head, body, tail): if result: @@ -531,7 +678,7 @@ def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator def right_hand_split( line: Line, - line_length: int, + mode: Mode, features: Collection[Feature] = (), omit: Collection[LeafID] = (), ) -> Iterator[Line]: @@ -543,6 +690,22 @@ def right_hand_split( Note: running this function modifies `bracket_depth` on the leaves of `line`. """ + rhs_result = _first_right_hand_split(line, omit=omit) + yield from _maybe_split_omitting_optional_parens( + rhs_result, line, mode, features=features, omit=omit + ) + + +def _first_right_hand_split( + line: Line, + omit: Collection[LeafID] = (), +) -> RHSResult: + """Split the line into head, body, tail starting with the last bracket pair. + + Note: this function should not have side effects. It's relied upon by + _maybe_split_omitting_optional_parens to get an opinion whether to prefer + splitting on the right side of an assignment statement. 
+ """ tail_leaves: List[Leaf] = [] body_leaves: List[Leaf] = [] head_leaves: List[Leaf] = [] @@ -568,41 +731,82 @@ def right_hand_split( tail_leaves.reverse() body_leaves.reverse() head_leaves.reverse() - head = bracket_split_build_line(head_leaves, line, opening_bracket) - body = bracket_split_build_line(body_leaves, line, opening_bracket, is_body=True) - tail = bracket_split_build_line(tail_leaves, line, opening_bracket) + head = bracket_split_build_line( + head_leaves, line, opening_bracket, component=_BracketSplitComponent.head + ) + body = bracket_split_build_line( + body_leaves, line, opening_bracket, component=_BracketSplitComponent.body + ) + tail = bracket_split_build_line( + tail_leaves, line, opening_bracket, component=_BracketSplitComponent.tail + ) bracket_split_succeeded_or_raise(head, body, tail) + return RHSResult(head, body, tail, opening_bracket, closing_bracket) + + +def _maybe_split_omitting_optional_parens( + rhs: RHSResult, + line: Line, + mode: Mode, + features: Collection[Feature] = (), + omit: Collection[LeafID] = (), +) -> Iterator[Line]: if ( Feature.FORCE_OPTIONAL_PARENTHESES not in features # the opening bracket is an optional paren - and opening_bracket.type == token.LPAR - and not opening_bracket.value + and rhs.opening_bracket.type == token.LPAR + and not rhs.opening_bracket.value # the closing bracket is an optional paren - and closing_bracket.type == token.RPAR - and not closing_bracket.value + and rhs.closing_bracket.type == token.RPAR + and not rhs.closing_bracket.value # it's not an import (optional parens are the only thing we can split on # in this case; attempting a split without them is a waste of time) and not line.is_import # there are no standalone comments in the body - and not body.contains_standalone_comments(0) + and not rhs.body.contains_standalone_comments(0) # and we can actually remove the parens - and can_omit_invisible_parens(body, line_length) + and can_omit_invisible_parens(rhs, mode.line_length) ): - omit = 
{id(closing_bracket), *omit} + omit = {id(rhs.closing_bracket), *omit} try: - yield from right_hand_split(line, line_length, features=features, omit=omit) - return + # The RHSResult Omitting Optional Parens. + rhs_oop = _first_right_hand_split(line, omit=omit) + if not ( + Preview.prefer_splitting_right_hand_side_of_assignments in line.mode + # the split is right after `=` + and len(rhs.head.leaves) >= 2 + and rhs.head.leaves[-2].type == token.EQUAL + # the left side of assignment contains brackets + and any(leaf.type in BRACKETS for leaf in rhs.head.leaves[:-1]) + # the left side of assignment is short enough (the -1 is for the ending + # optional paren) + and is_line_short_enough( + rhs.head, mode=replace(mode, line_length=mode.line_length - 1) + ) + # the left side of assignment won't explode further because of magic + # trailing comma + and rhs.head.magic_trailing_comma is None + # the split by omitting optional parens isn't preferred by some other + # reason + and not _prefer_split_rhs_oop(rhs_oop, mode) + ): + yield from _maybe_split_omitting_optional_parens( + rhs_oop, line, mode, features=features, omit=omit + ) + return except CannotSplit as e: if not ( - can_be_split(body) - or is_line_short_enough(body, line_length=line_length) + can_be_split(rhs.body) or is_line_short_enough(rhs.body, mode=mode) ): raise CannotSplit( "Splitting failed, body is still too long and can't be split." ) from e - elif head.contains_multiline_strings() or tail.contains_multiline_strings(): + elif ( + rhs.head.contains_multiline_strings() + or rhs.tail.contains_multiline_strings() + ): raise CannotSplit( "The current optional pair of parentheses is bound to fail to" " satisfy the splitting algorithm because the head or the tail" @@ -610,13 +814,42 @@ def right_hand_split( " line." 
) from e - ensure_visible(opening_bracket) - ensure_visible(closing_bracket) - for result in (head, body, tail): + ensure_visible(rhs.opening_bracket) + ensure_visible(rhs.closing_bracket) + for result in (rhs.head, rhs.body, rhs.tail): if result: yield result +def _prefer_split_rhs_oop(rhs_oop: RHSResult, mode: Mode) -> bool: + """ + Returns whether we should prefer the result from a split omitting optional parens. + """ + has_closing_bracket_after_assign = False + for leaf in reversed(rhs_oop.head.leaves): + if leaf.type == token.EQUAL: + break + if leaf.type in CLOSING_BRACKETS: + has_closing_bracket_after_assign = True + break + return ( + # contains matching brackets after the `=` (done by checking there is a + # closing bracket) + has_closing_bracket_after_assign + or ( + # the split is actually from inside the optional parens (done by checking + # the first line still contains the `=`) + any(leaf.type == token.EQUAL for leaf in rhs_oop.head.leaves) + # the first line is short enough + and is_line_short_enough(rhs_oop.head, mode=mode) + ) + # contains unsplittable type ignore + or rhs_oop.head.contains_unsplittable_type_ignore() + or rhs_oop.body.contains_unsplittable_type_ignore() + or rhs_oop.tail.contains_unsplittable_type_ignore() + ) + + def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None: """Raise :exc:`CannotSplit` if the last left- or right-hand split failed. @@ -644,15 +877,23 @@ def bracket_split_succeeded_or_raise(head: Line, body: Line, tail: Line) -> None def bracket_split_build_line( - leaves: List[Leaf], original: Line, opening_bracket: Leaf, *, is_body: bool = False + leaves: List[Leaf], + original: Line, + opening_bracket: Leaf, + *, + component: _BracketSplitComponent, ) -> Line: """Return a new line with given `leaves` and respective comments from `original`. 
- If `is_body` is True, the result line is one-indented inside brackets and as such - has its first leaf's prefix normalized and a trailing comma added when expected. + If it's the head component, brackets will be tracked so trailing commas are + respected. + + If it's the body component, the result line is one-indented inside brackets and as + such has its first leaf's prefix normalized and a trailing comma added when + expected. """ result = Line(mode=original.mode, depth=original.depth) - if is_body: + if component is _BracketSplitComponent.body: result.inside_brackets = True result.depth += 1 if leaves: @@ -678,6 +919,13 @@ def bracket_split_build_line( ) if isinstance(node, Node) and isinstance(node.prev_sibling, Leaf) ) + # Except the false negatives above for PEP 604 unions where we + # can't add the comma. + and not ( + leaves[0].parent + and leaves[0].parent.next_sibling + and leaves[0].parent.next_sibling.type == token.VBAR + ) ) if original.is_import or no_commas: @@ -690,12 +938,21 @@ def bracket_split_build_line( leaves.insert(i + 1, new_comma) break + leaves_to_track: Set[LeafID] = set() + if component is _BracketSplitComponent.head: + leaves_to_track = get_leaves_inside_matching_brackets(leaves) # Populate the line for leaf in leaves: - result.append(leaf, preformatted=True) + result.append( + leaf, + preformatted=True, + track_bracket=id(leaf) in leaves_to_track, + ) for comment_after in original.comments_after(leaf): result.append(comment_after, preformatted=True) - if is_body and should_split_line(result, opening_bracket): + if component is _BracketSplitComponent.body and should_split_line( + result, opening_bracket + ): result.should_split_rhs = True return result @@ -707,16 +964,39 @@ def dont_increase_indentation(split_func: Transformer) -> Transformer: """ @wraps(split_func) - def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]: - for split_line in split_func(line, features): + def split_wrapper( + line: Line, 
features: Collection[Feature], mode: Mode + ) -> Iterator[Line]: + for split_line in split_func(line, features, mode): normalize_prefix(split_line.leaves[0], inside_brackets=True) yield split_line return split_wrapper +def _get_last_non_comment_leaf(line: Line) -> Optional[int]: + for leaf_idx in range(len(line.leaves) - 1, 0, -1): + if line.leaves[leaf_idx].type != STANDALONE_COMMENT: + return leaf_idx + return None + + +def _safe_add_trailing_comma(safe: bool, delimiter_priority: int, line: Line) -> Line: + if ( + safe + and delimiter_priority == COMMA_PRIORITY + and line.leaves[-1].type != token.COMMA + and line.leaves[-1].type != STANDALONE_COMMENT + ): + new_comma = Leaf(token.COMMA, ",") + line.append(new_comma) + return line + + @dont_increase_indentation -def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]: +def delimiter_split( + line: Line, features: Collection[Feature], mode: Mode +) -> Iterator[Line]: """Split according to delimiters of the highest priority. 
If the appropriate Features are given, the split will add trailing commas @@ -756,7 +1036,8 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ ) current_line.append(leaf) - for leaf in line.leaves: + last_non_comment_leaf = _get_last_non_comment_leaf(line) + for leaf_idx, leaf in enumerate(line.leaves): yield from append_to_line(leaf) for comment_after in line.comments_after(leaf): @@ -773,6 +1054,15 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features ) + if ( + Preview.add_trailing_comma_consistently in mode + and last_leaf.type == STANDALONE_COMMENT + and leaf_idx == last_non_comment_leaf + ): + current_line = _safe_add_trailing_comma( + trailing_comma_safe, delimiter_priority, current_line + ) + leaf_priority = bt.delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: yield current_line @@ -781,20 +1071,15 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets ) if current_line: - if ( - trailing_comma_safe - and delimiter_priority == COMMA_PRIORITY - and current_line.leaves[-1].type != token.COMMA - and current_line.leaves[-1].type != STANDALONE_COMMENT - ): - new_comma = Leaf(token.COMMA, ",") - current_line.append(new_comma) + current_line = _safe_add_trailing_comma( + trailing_comma_safe, delimiter_priority, current_line + ) yield current_line @dont_increase_indentation def standalone_comment_split( - line: Line, features: Collection[Feature] = () + line: Line, features: Collection[Feature], mode: Mode ) -> Iterator[Line]: """Split standalone comments from the rest of the line.""" if not line.contains_standalone_comments(0): @@ -846,7 +1131,7 @@ def normalize_prefix(leaf: Leaf, *, inside_brackets: bool) -> None: def normalize_invisible_parens( - node: Node, parens_after: Set[str], *, preview: bool + node: Node, 
parens_after: Set[str], *, mode: Mode, features: Collection[Feature] ) -> None: """Make existing optional parentheses invisible or create new ones. @@ -856,17 +1141,24 @@ def normalize_invisible_parens( Standardizes on visible parentheses for single-element tuples, and keeps existing visible parentheses for other tuples and generator expressions. """ - for pc in list_comments(node.prefix, is_endmarker=False, preview=preview): + for pc in list_comments(node.prefix, is_endmarker=False): if pc.value in FMT_OFF: # This `node` has a prefix with `# fmt: off`, don't mess with parens. return + + # The multiple context managers grammar has a different pattern, thus this is + # separate from the for-loop below. This possibly wraps them in invisible parens, + # and later will be removed in remove_with_parens when needed. + if node.type == syms.with_stmt: + _maybe_wrap_cms_in_parens(node, mode, features) + check_lpar = False for index, child in enumerate(list(node.children)): # Fixes a bug where invisible parens are not properly stripped from # assignment statements that contain type annotations. if isinstance(child, Node) and child.type == syms.annassign: normalize_invisible_parens( - child, parens_after=parens_after, preview=preview + child, parens_after=parens_after, mode=mode, features=features ) # Add parentheses around long tuple unpacking in assignments. 
@@ -879,8 +1171,7 @@ def normalize_invisible_parens( if check_lpar: if ( - preview - and child.type == syms.atom + child.type == syms.atom and node.type == syms.for_stmt and isinstance(child.prev_sibling, Leaf) and child.prev_sibling.type == token.NAME @@ -892,7 +1183,7 @@ def normalize_invisible_parens( remove_brackets_around_comma=True, ): wrap_in_parentheses(node, child, visible=False) - elif preview and isinstance(child, Node) and node.type == syms.with_stmt: + elif isinstance(child, Node) and node.type == syms.with_stmt: remove_with_parens(child, node) elif child.type == syms.atom: if maybe_make_parens_invisible_in_atom( @@ -903,29 +1194,42 @@ def normalize_invisible_parens( elif is_one_tuple(child): wrap_in_parentheses(node, child, visible=True) elif node.type == syms.import_from: - # "import from" nodes store parentheses directly as part of - # the statement - if is_lpar_token(child): - assert is_rpar_token(node.children[-1]) - # make parentheses invisible - child.value = "" - node.children[-1].value = "" - elif child.type != token.STAR: - # insert invisible parentheses - node.insert_child(index, Leaf(token.LPAR, "")) - node.append_child(Leaf(token.RPAR, "")) + _normalize_import_from(node, child, index) break + elif ( + index == 1 + and child.type == token.STAR + and node.type == syms.except_clause + ): + # In except* (PEP 654), the star is actually part of + # the keyword. So we need to skip the insertion of + # invisible parentheses to work more precisely.
+ continue elif not (isinstance(child, Leaf) and is_multiline_string(child)): wrap_in_parentheses(node, child, visible=False) - comma_check = child.type == token.COMMA if preview else False + comma_check = child.type == token.COMMA check_lpar = isinstance(child, Leaf) and ( child.value in parens_after or comma_check ) +def _normalize_import_from(parent: Node, child: LN, index: int) -> None: + # "import from" nodes store parentheses directly as part of + # the statement + if is_lpar_token(child): + assert is_rpar_token(parent.children[-1]) + # make parentheses invisible + child.value = "" + parent.children[-1].value = "" + elif child.type != token.STAR: + # insert invisible parentheses + parent.insert_child(index, Leaf(token.LPAR, "")) + parent.append_child(Leaf(token.RPAR, "")) + + def remove_await_parens(node: Node) -> None: if node.children[0].type == token.AWAIT and len(node.children) > 1: if ( @@ -948,18 +1252,62 @@ def remove_await_parens(node: Node) -> None: # N.B. We've still removed any redundant nested brackets though :) opening_bracket = cast(Leaf, node.children[1].children[0]) closing_bracket = cast(Leaf, node.children[1].children[-1]) - bracket_contents = cast(Node, node.children[1].children[1]) - if bracket_contents.type != syms.power: - ensure_visible(opening_bracket) - ensure_visible(closing_bracket) - elif ( - bracket_contents.type == syms.power - and bracket_contents.children[0].type == token.AWAIT - ): - ensure_visible(opening_bracket) - ensure_visible(closing_bracket) - # If we are in a nested await then recurse down. 
- remove_await_parens(bracket_contents) + bracket_contents = node.children[1].children[1] + if isinstance(bracket_contents, Node): + if bracket_contents.type != syms.power: + ensure_visible(opening_bracket) + ensure_visible(closing_bracket) + elif ( + bracket_contents.type == syms.power + and bracket_contents.children[0].type == token.AWAIT + ): + ensure_visible(opening_bracket) + ensure_visible(closing_bracket) + # If we are in a nested await then recurse down. + remove_await_parens(bracket_contents) + + +def _maybe_wrap_cms_in_parens( + node: Node, mode: Mode, features: Collection[Feature] +) -> None: + """When enabled and safe, wrap the multiple context managers in invisible parens. + + It is only safe when `features` contain Feature.PARENTHESIZED_CONTEXT_MANAGERS. + """ + if ( + Feature.PARENTHESIZED_CONTEXT_MANAGERS not in features + or Preview.wrap_multiple_context_managers_in_parens not in mode + or len(node.children) <= 2 + # If it's an atom, it's already wrapped in parens. + or node.children[1].type == syms.atom + ): + return + colon_index: Optional[int] = None + for i in range(2, len(node.children)): + if node.children[i].type == token.COLON: + colon_index = i + break + if colon_index is not None: + lpar = Leaf(token.LPAR, "") + rpar = Leaf(token.RPAR, "") + context_managers = node.children[1:colon_index] + for child in context_managers: + child.remove() + # After wrapping, the with_stmt will look like this: + # with_stmt + # NAME 'with' + # atom + # LPAR '' + # testlist_gexp + # ... 
<-- context_managers + # /testlist_gexp + # RPAR '' + # /atom + # COLON ':' + new_child = Node( + syms.atom, [lpar, Node(syms.testlist_gexp, context_managers), rpar] + ) + node.insert_child(1, new_child) def remove_with_parens(node: Node, parent: Node) -> None: @@ -1027,6 +1375,7 @@ def maybe_make_parens_invisible_in_atom( not remove_brackets_around_comma and max_delimiter_priority_in_atom(node) >= COMMA_PRIORITY ) + or is_tuple_containing_walrus(node) ): return False @@ -1036,9 +1385,13 @@ def maybe_make_parens_invisible_in_atom( syms.expr_stmt, syms.assert_stmt, syms.return_stmt, + syms.except_clause, + syms.funcdef, + syms.with_stmt, # these ones aren't useful to end users, but they do please fuzzers syms.for_stmt, syms.del_stmt, + syms.for_stmt, ]: return False @@ -1047,8 +1400,13 @@ def maybe_make_parens_invisible_in_atom( if is_lpar_token(first) and is_rpar_token(last): middle = node.children[1] # make parentheses invisible - first.value = "" - last.value = "" + if ( + # If the prefix of `middle` includes a type comment with + # ignore annotation, then we do not remove the parentheses + not is_type_ignore_comment_string(middle.prefix.strip()) + ): + first.value = "" + last.value = "" maybe_make_parens_invisible_in_atom( middle, parent=parent, @@ -1179,20 +1537,22 @@ def run_transformer( if not line_str: line_str = line_to_string(line) result: List[Line] = [] - for transformed_line in transform(line, features): + for transformed_line in transform(line, features, mode): if str(transformed_line).strip("\n") == line_str: raise CannotTransform("Line transformer returned an unchanged result") result.extend(transform_line(transformed_line, mode=mode, features=features)) + features_set = set(features) if ( - transform.__class__.__name__ != "rhs" + Feature.FORCE_OPTIONAL_PARENTHESES in features_set + or transform.__class__.__name__ != "rhs" or not line.bracket_tracker.invisible or any(bracket.value for bracket in line.bracket_tracker.invisible) or 
line.contains_multiline_strings() or result[0].contains_uncollapsable_type_comments() or result[0].contains_unsplittable_type_ignore() - or is_line_short_enough(result[0], line_length=mode.line_length) + or is_line_short_enough(result[0], mode=mode) # If any leaves have no parents (which _can_ occur since # `transform(line)` potentially destroys the line's underlying node # structure), then we can't proceed. Doing so would cause the below @@ -1203,12 +1563,10 @@ def run_transformer( line_copy = line.clone() append_leaves(line_copy, line, line.leaves) - features_fop = set(features) | {Feature.FORCE_OPTIONAL_PARENTHESES} + features_fop = features_set | {Feature.FORCE_OPTIONAL_PARENTHESES} second_opinion = run_transformer( line_copy, transform, mode, features_fop, line_str=line_str ) - if all( - is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion - ): + if all(is_line_short_enough(ln, mode=mode) for ln in second_opinion): result = second_opinion return result