X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/91e1e1328aa0a11ef50017316ff97149886e1b05..bf7a16254ec96b084a6caf3d435ec18f0f245cc7:/src/black/linegen.py diff --git a/src/black/linegen.py b/src/black/linegen.py index 2f50257..b6b83da 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -2,7 +2,7 @@ Generating lines of code. """ import sys -from dataclasses import dataclass +from dataclasses import replace from enum import Enum, auto from functools import partial, wraps from typing import Collection, Iterator, List, Optional, Set, Union, cast @@ -16,6 +16,7 @@ from black.brackets import ( from black.comments import FMT_OFF, generate_comments, list_comments from black.lines import ( Line, + RHSResult, append_leaves, can_be_split, can_omit_invisible_parens, @@ -35,6 +36,7 @@ from black.nodes import ( Visitor, ensure_visible, is_arith_like, + is_async_stmt_or_funcdef, is_atom_with_invisible_parens, is_docstring, is_empty_tuple, @@ -59,6 +61,7 @@ from black.strings import ( get_string_prefix, normalize_string_prefix, normalize_string_quotes, + normalize_unicode_escape_sequences, ) from black.trans import ( CannotTransform, @@ -108,6 +111,17 @@ class LineGenerator(Visitor[Line]): self.current_line.depth += indent return # Line is empty, don't emit. Creating a new one unnecessary. + if ( + Preview.improved_async_statements_handling in self.mode + and len(self.current_line.leaves) == 1 + and is_async_stmt_or_funcdef(self.current_line.leaves[0]) + ): + # Special case for async def/for/with statements. `visit_async_stmt` + # adds an `ASYNC` leaf then visits the child def/for/with statement + # nodes. Line yields from those nodes shouldn't treat the former + # `ASYNC` leaf as a complete line. + return + complete_line = self.current_line self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent) yield complete_line @@ -116,7 +130,7 @@ class LineGenerator(Visitor[Line]): """Default `visit_*()` implementation. Recurses to children of `node`.""" if isinstance(node, Leaf): any_open_brackets = self.current_line.bracket_tracker.any_open_brackets() - for comment in generate_comments(node, preview=self.mode.preview): + for comment in generate_comments(node): if any_open_brackets: # any comment within brackets is subject to splitting self.current_line.append(comment) @@ -220,30 +234,27 @@ class LineGenerator(Visitor[Line]): def visit_funcdef(self, node: Node) -> Iterator[Line]: """Visit function definition.""" - if Preview.annotation_parens not in self.mode: - yield from self.visit_stmt(node, keywords={"def"}, parens=set()) - else: - yield from self.line() + yield from self.line() - # Remove redundant brackets around return type annotation. - is_return_annotation = False - for child in node.children: - if child.type == token.RARROW: - is_return_annotation = True - elif is_return_annotation: - if child.type == syms.atom and child.children[0].type == token.LPAR: - if maybe_make_parens_invisible_in_atom( - child, - parent=node, - remove_brackets_around_comma=False, - ): - wrap_in_parentheses(node, child, visible=False) - else: + # Remove redundant brackets around return type annotation. + is_return_annotation = False + for child in node.children: + if child.type == token.RARROW: + is_return_annotation = True + elif is_return_annotation: + if child.type == syms.atom and child.children[0].type == token.LPAR: + if maybe_make_parens_invisible_in_atom( + child, + parent=node, + remove_brackets_around_comma=False, + ): wrap_in_parentheses(node, child, visible=False) - is_return_annotation = False + else: + wrap_in_parentheses(node, child, visible=False) + is_return_annotation = False - for child in node.children: - yield from self.visit(child) + for child in node.children: + yield from self.visit(child) def visit_match_case(self, node: Node) -> Iterator[Line]: """Visit either a match or case statement.""" @@ -302,8 +313,11 @@ class LineGenerator(Visitor[Line]): break internal_stmt = next(children) - for child in internal_stmt.children: - yield from self.visit(child) + if Preview.improved_async_statements_handling in self.mode: + yield from self.visit(internal_stmt) + else: + for child in internal_stmt.children: + yield from self.visit(child) def visit_decorators(self, node: Node) -> Iterator[Line]: """Visit decorators.""" @@ -331,8 +345,7 @@ class LineGenerator(Visitor[Line]): ): wrap_in_parentheses(node, leaf) - if Preview.remove_redundant_parens in self.mode: - remove_await_parens(node) + remove_await_parens(node) yield from self.visit_default(node) @@ -368,27 +381,23 @@ class LineGenerator(Visitor[Line]): yield from self.visit_default(node) def visit_STRING(self, leaf: Leaf) -> Iterator[Line]: + if Preview.hex_codes_in_unicode_sequences in self.mode: + normalize_unicode_escape_sequences(leaf) + if is_docstring(leaf) and "\\\n" not in leaf.value: # We're ignoring docstrings with backslash newline escapes because changing # indentation of those changes the AST representation of the code. - if Preview.normalize_docstring_quotes_and_prefixes_properly in self.mode: - # There was a bug where --skip-string-normalization wouldn't stop us - # from normalizing docstring prefixes. To maintain stability, we can - # only address this buggy behaviour while the preview style is enabled. - if self.mode.string_normalization: - docstring = normalize_string_prefix(leaf.value) - # visit_default() does handle string normalization for us, but - # since this method acts differently depending on quote style (ex. - # see padding logic below), there's a possibility for unstable - # formatting as visit_default() is called *after*. To avoid a - # situation where this function formats a docstring differently on - # the second pass, normalize it early. - docstring = normalize_string_quotes(docstring) - else: - docstring = leaf.value - else: - # ... otherwise, we'll keep the buggy behaviour >.< + if self.mode.string_normalization: docstring = normalize_string_prefix(leaf.value) + # visit_default() does handle string normalization for us, but + # since this method acts differently depending on quote style (ex. + # see padding logic below), there's a possibility for unstable + # formatting as visit_default() is called *after*. To avoid a + # situation where this function formats a docstring differently on + # the second pass, normalize it early. + docstring = normalize_string_quotes(docstring) + else: + docstring = leaf.value prefix = get_string_prefix(docstring) docstring = docstring[len(prefix) :] # Remove the prefix quote_char = docstring[0] @@ -428,7 +437,7 @@ class LineGenerator(Visitor[Line]): quote = quote_char * quote_len # It's invalid to put closing single-character quotes on a new line. - if Preview.long_docstring_quotes_on_newline in self.mode and quote_len == 3: + if self.mode and quote_len == 3: # We need to find the length of the last line of the docstring # to find if we can add the closing quotes to the line without # exceeding the maximum line length. @@ -469,14 +478,8 @@ class LineGenerator(Visitor[Line]): self.visit_try_stmt = partial( v, keywords={"try", "except", "else", "finally"}, parens=Ø ) - if self.mode.preview: - self.visit_except_clause = partial( - v, keywords={"except"}, parens={"except"} - ) - self.visit_with_stmt = partial(v, keywords={"with"}, parens={"with"}) - else: - self.visit_except_clause = partial(v, keywords={"except"}, parens=Ø) - self.visit_with_stmt = partial(v, keywords={"with"}, parens=Ø) + self.visit_except_clause = partial(v, keywords={"except"}, parens={"except"}) + self.visit_with_stmt = partial(v, keywords={"with"}, parens={"with"}) self.visit_classdef = partial(v, keywords={"class"}, parens=Ø) self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS) self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"}) @@ -518,7 +521,7 @@ def transform_line( and not line.should_split_rhs and not line.magic_trailing_comma and ( - is_line_short_enough(line, line_length=mode.line_length, line_str=line_str) + is_line_short_enough(line, mode=mode, line_str=line_str) or line.contains_unsplittable_type_ignore() ) and not (line.inside_brackets and line.contains_standalone_comments()) @@ -533,7 +536,7 @@ def transform_line( else: def _rhs( - self: object, line: Line, features: Collection[Feature] + self: object, line: Line, features: Collection[Feature], mode: Mode ) -> Iterator[Line]: """Wraps calls to `right_hand_split`. @@ -542,14 +545,12 @@ def transform_line( bracket pair instead. """ for omit in generate_trailers_to_omit(line, mode.line_length): - lines = list( - right_hand_split(line, mode.line_length, features, omit=omit) - ) + lines = list(right_hand_split(line, mode, features, omit=omit)) # Note: this check is only able to figure out if the first line of the # *current* transformation fits in the line length. This is true only # for simple cases. All others require running more transforms via # `transform_line()`. This check doesn't know if those would succeed. - if is_line_short_enough(lines[0], line_length=mode.line_length): + if is_line_short_enough(lines[0], mode=mode): yield from lines return @@ -557,9 +558,7 @@ def transform_line( # This mostly happens to multiline strings that are by definition # reported as not fitting a single line, as well as lines that contain # trailing commas (those have to be exploded). - yield from right_hand_split( - line, line_length=mode.line_length, features=features - ) + yield from right_hand_split(line, mode, features=features) # HACK: nested functions (like _rhs) compiled by mypyc don't retain their # __name__ attribute which is needed in `run_transformer` further down. @@ -617,7 +616,9 @@ class _BracketSplitComponent(Enum): tail = auto() -def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator[Line]: +def left_hand_split( + line: Line, _features: Collection[Feature], mode: Mode +) -> Iterator[Line]: """Split line into many lines, starting with the first matching bracket pair. Note: this usually looks weird, only use this for function definitions. @@ -662,20 +663,9 @@ def left_hand_split(line: Line, _features: Collection[Feature] = ()) -> Iterator yield result -@dataclass -class _RHSResult: - """Intermediate split result from a right hand split.""" - - head: Line - body: Line - tail: Line - opening_bracket: Leaf - closing_bracket: Leaf - - def right_hand_split( line: Line, - line_length: int, + mode: Mode, features: Collection[Feature] = (), omit: Collection[LeafID] = (), ) -> Iterator[Line]: @@ -689,14 +679,14 @@ def right_hand_split( """ rhs_result = _first_right_hand_split(line, omit=omit) yield from _maybe_split_omitting_optional_parens( - rhs_result, line, line_length, features=features, omit=omit + rhs_result, line, mode, features=features, omit=omit ) def _first_right_hand_split( line: Line, omit: Collection[LeafID] = (), -) -> _RHSResult: +) -> RHSResult: """Split the line into head, body, tail starting with the last bracket pair. Note: this function should not have side effects. It's relied upon by @@ -738,13 +728,13 @@ def _first_right_hand_split( tail_leaves, line, opening_bracket, component=_BracketSplitComponent.tail ) bracket_split_succeeded_or_raise(head, body, tail) - return _RHSResult(head, body, tail, opening_bracket, closing_bracket) + return RHSResult(head, body, tail, opening_bracket, closing_bracket) def _maybe_split_omitting_optional_parens( - rhs: _RHSResult, + rhs: RHSResult, line: Line, - line_length: int, + mode: Mode, features: Collection[Feature] = (), omit: Collection[LeafID] = (), ) -> Iterator[Line]: @@ -762,38 +752,39 @@ def _maybe_split_omitting_optional_parens( # there are no standalone comments in the body and not rhs.body.contains_standalone_comments(0) # and we can actually remove the parens - and can_omit_invisible_parens(rhs.body, line_length) + and can_omit_invisible_parens(rhs, mode.line_length) ): omit = {id(rhs.closing_bracket), *omit} try: - # The _RHSResult Omitting Optional Parens. + # The RHSResult Omitting Optional Parens. rhs_oop = _first_right_hand_split(line, omit=omit) if not ( Preview.prefer_splitting_right_hand_side_of_assignments in line.mode # the split is right after `=` and len(rhs.head.leaves) >= 2 and rhs.head.leaves[-2].type == token.EQUAL - # the left side of assignement contains brackets + # the left side of assignment contains brackets and any(leaf.type in BRACKETS for leaf in rhs.head.leaves[:-1]) # the left side of assignment is short enough (the -1 is for the ending # optional paren) - and is_line_short_enough(rhs.head, line_length=line_length - 1) + and is_line_short_enough( + rhs.head, mode=replace(mode, line_length=mode.line_length - 1) + ) # the left side of assignment won't explode further because of magic # trailing comma and rhs.head.magic_trailing_comma is None # the split by omitting optional parens isn't preferred by some other # reason - and not _prefer_split_rhs_oop(rhs_oop, line_length=line_length) + and not _prefer_split_rhs_oop(rhs_oop, mode) ): yield from _maybe_split_omitting_optional_parens( - rhs_oop, line, line_length, features=features, omit=omit + rhs_oop, line, mode, features=features, omit=omit ) return except CannotSplit as e: if not ( - can_be_split(rhs.body) - or is_line_short_enough(rhs.body, line_length=line_length) + can_be_split(rhs.body) or is_line_short_enough(rhs.body, mode=mode) ): raise CannotSplit( "Splitting failed, body is still too long and can't be split." @@ -817,7 +808,7 @@ def _maybe_split_omitting_optional_parens( yield result -def _prefer_split_rhs_oop(rhs_oop: _RHSResult, line_length: int) -> bool: +def _prefer_split_rhs_oop(rhs_oop: RHSResult, mode: Mode) -> bool: """ Returns whether we should prefer the result from a split omitting optional parens. """ @@ -837,7 +828,7 @@ def _prefer_split_rhs_oop(rhs_oop: _RHSResult, line_length: int) -> bool: # the first line still contains the `=`) any(leaf.type == token.EQUAL for leaf in rhs_oop.head.leaves) # the first line is short enough - and is_line_short_enough(rhs_oop.head, line_length=line_length) + and is_line_short_enough(rhs_oop.head, mode=mode) ) # contains unsplittable type ignore or rhs_oop.head.contains_unsplittable_type_ignore() @@ -928,10 +919,7 @@ def bracket_split_build_line( break leaves_to_track: Set[LeafID] = set() - if ( - Preview.handle_trailing_commas_in_head in original.mode - and component is _BracketSplitComponent.head - ): + if component is _BracketSplitComponent.head: leaves_to_track = get_leaves_inside_matching_brackets(leaves) # Populate the line for leaf in leaves: @@ -956,16 +944,39 @@ def dont_increase_indentation(split_func: Transformer) -> Transformer: """ @wraps(split_func) - def split_wrapper(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]: - for split_line in split_func(line, features): + def split_wrapper( + line: Line, features: Collection[Feature], mode: Mode + ) -> Iterator[Line]: + for split_line in split_func(line, features, mode): normalize_prefix(split_line.leaves[0], inside_brackets=True) yield split_line return split_wrapper +def _get_last_non_comment_leaf(line: Line) -> Optional[int]: + for leaf_idx in range(len(line.leaves) - 1, 0, -1): + if line.leaves[leaf_idx].type != STANDALONE_COMMENT: + return leaf_idx + return None + + +def _safe_add_trailing_comma(safe: bool, delimiter_priority: int, line: Line) -> Line: + if ( + safe + and delimiter_priority == COMMA_PRIORITY + and line.leaves[-1].type != token.COMMA + and line.leaves[-1].type != STANDALONE_COMMENT + ): + new_comma = Leaf(token.COMMA, ",") + line.append(new_comma) + return line + + @dont_increase_indentation -def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[Line]: +def delimiter_split( + line: Line, features: Collection[Feature], mode: Mode +) -> Iterator[Line]: """Split according to delimiters of the highest priority. If the appropriate Features are given, the split will add trailing commas @@ -1005,7 +1016,8 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ ) current_line.append(leaf) - for leaf in line.leaves: + last_non_comment_leaf = _get_last_non_comment_leaf(line) + for leaf_idx, leaf in enumerate(line.leaves): yield from append_to_line(leaf) for comment_after in line.comments_after(leaf): @@ -1022,6 +1034,15 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ trailing_comma_safe and Feature.TRAILING_COMMA_IN_CALL in features ) + if ( + Preview.add_trailing_comma_consistently in mode + and last_leaf.type == STANDALONE_COMMENT + and leaf_idx == last_non_comment_leaf + ): + current_line = _safe_add_trailing_comma( + trailing_comma_safe, delimiter_priority, current_line + ) + leaf_priority = bt.delimiters.get(id(leaf)) if leaf_priority == delimiter_priority: yield current_line @@ -1030,20 +1051,15 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[ mode=line.mode, depth=line.depth, inside_brackets=line.inside_brackets ) if current_line: - if ( - trailing_comma_safe - and delimiter_priority == COMMA_PRIORITY - and current_line.leaves[-1].type != token.COMMA - and current_line.leaves[-1].type != STANDALONE_COMMENT - ): - new_comma = Leaf(token.COMMA, ",") - current_line.append(new_comma) + current_line = _safe_add_trailing_comma( + trailing_comma_safe, delimiter_priority, current_line + ) yield current_line @dont_increase_indentation def standalone_comment_split( - line: Line, features: Collection[Feature] = () + line: Line, features: Collection[Feature], mode: Mode ) -> Iterator[Line]: """Split standalone comments from the rest of the line.""" if not line.contains_standalone_comments(0): @@ -1105,7 +1121,7 @@ def normalize_invisible_parens( Standardizes on visible parentheses for single-element tuples, and keeps existing visible parentheses for other tuples and generator expressions. """ - for pc in list_comments(node.prefix, is_endmarker=False, preview=mode.preview): + for pc in list_comments(node.prefix, is_endmarker=False): if pc.value in FMT_OFF: # This `node` has a prefix with `# fmt: off`, don't mess with parens. return @@ -1135,8 +1151,7 @@ def normalize_invisible_parens( if check_lpar: if ( - mode.preview - and child.type == syms.atom + child.type == syms.atom and node.type == syms.for_stmt and isinstance(child.prev_sibling, Leaf) and child.prev_sibling.type == token.NAME @@ -1148,9 +1163,7 @@ def normalize_invisible_parens( remove_brackets_around_comma=True, ): wrap_in_parentheses(node, child, visible=False) - elif ( - mode.preview and isinstance(child, Node) and node.type == syms.with_stmt - ): + elif isinstance(child, Node) and node.type == syms.with_stmt: remove_with_parens(child, node) elif child.type == syms.atom: if maybe_make_parens_invisible_in_atom( @@ -1176,7 +1189,7 @@ def normalize_invisible_parens( elif not (isinstance(child, Leaf) and is_multiline_string(child)): wrap_in_parentheses(node, child, visible=False) - comma_check = child.type == token.COMMA if mode.preview else False + comma_check = child.type == token.COMMA check_lpar = isinstance(child, Leaf) and ( child.value in parens_after or comma_check @@ -1219,18 +1232,19 @@ def remove_await_parens(node: Node) -> None: # N.B. We've still removed any redundant nested brackets though :) opening_bracket = cast(Leaf, node.children[1].children[0]) closing_bracket = cast(Leaf, node.children[1].children[-1]) - bracket_contents = cast(Node, node.children[1].children[1]) - if bracket_contents.type != syms.power: - ensure_visible(opening_bracket) - ensure_visible(closing_bracket) - elif ( - bracket_contents.type == syms.power - and bracket_contents.children[0].type == token.AWAIT - ): - ensure_visible(opening_bracket) - ensure_visible(closing_bracket) - # If we are in a nested await then recurse down. - remove_await_parens(bracket_contents) + bracket_contents = node.children[1].children[1] + if isinstance(bracket_contents, Node): + if bracket_contents.type != syms.power: + ensure_visible(opening_bracket) + ensure_visible(closing_bracket) + elif ( + bracket_contents.type == syms.power + and bracket_contents.children[0].type == token.AWAIT + ): + ensure_visible(opening_bracket) + ensure_visible(closing_bracket) + # If we are in a nested await then recurse down. + remove_await_parens(bracket_contents) def _maybe_wrap_cms_in_parens( @@ -1498,7 +1512,7 @@ def run_transformer( if not line_str: line_str = line_to_string(line) result: List[Line] = [] - for transformed_line in transform(line, features): + for transformed_line in transform(line, features, mode): if str(transformed_line).strip("\n") == line_str: raise CannotTransform("Line transformer returned an unchanged result") @@ -1513,7 +1527,7 @@ def run_transformer( or line.contains_multiline_strings() or result[0].contains_uncollapsable_type_comments() or result[0].contains_unsplittable_type_ignore() - or is_line_short_enough(result[0], line_length=mode.line_length) + or is_line_short_enough(result[0], mode=mode) # If any leaves have no parents (which _can_ occur since # `transform(line)` potentially destroys the line's underlying node # structure), then we can't proceed. Doing so would cause the below @@ -1528,8 +1542,6 @@ def run_transformer( second_opinion = run_transformer( line_copy, transform, mode, features_fop, line_str=line_str ) - if all( - is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion - ): + if all(is_line_short_enough(ln, mode=mode) for ln in second_opinion): result = second_opinion return result