X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/9c8464ca7ddd48d1c19112d895ae12d783f01563..935f303a0a7b794e722c7df00c906be285884874:/src/black/linegen.py?ds=sidebyside diff --git a/src/black/linegen.py b/src/black/linegen.py index f7d3655..faeb3ba 100644 --- a/src/black/linegen.py +++ b/src/black/linegen.py @@ -1,8 +1,9 @@ """ Generating lines of code. """ + import sys -from dataclasses import dataclass +from dataclasses import replace from enum import Enum, auto from functools import partial, wraps from typing import Collection, Iterator, List, Optional, Set, Union, cast @@ -16,6 +17,7 @@ from black.brackets import ( from black.comments import FMT_OFF, generate_comments, list_comments from black.lines import ( Line, + RHSResult, append_leaves, can_be_split, can_omit_invisible_parens, @@ -35,6 +37,7 @@ from black.nodes import ( Visitor, ensure_visible, is_arith_like, + is_async_stmt_or_funcdef, is_atom_with_invisible_parens, is_docstring, is_empty_tuple, @@ -47,6 +50,7 @@ from black.nodes import ( is_stub_body, is_stub_suite, is_tuple_containing_walrus, + is_type_ignore_comment_string, is_vararg, is_walrus_assignment, is_yield, @@ -109,6 +113,17 @@ class LineGenerator(Visitor[Line]): self.current_line.depth += indent return # Line is empty, don't emit. Creating a new one unnecessary. + if ( + Preview.improved_async_statements_handling in self.mode + and len(self.current_line.leaves) == 1 + and is_async_stmt_or_funcdef(self.current_line.leaves[0]) + ): + # Special case for async def/for/with statements. `visit_async_stmt` + # adds an `ASYNC` leaf then visits the child def/for/with statement + # nodes. Line yields from those nodes shouldn't treat the former + # `ASYNC` leaf as a complete line. + return + complete_line = self.current_line self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent) yield complete_line @@ -202,6 +217,18 @@ class LineGenerator(Visitor[Line]): yield from self.visit(child) + def visit_typeparams(self, node: Node) -> Iterator[Line]: + yield from self.visit_default(node) + node.children[0].prefix = "" + + def visit_typevartuple(self, node: Node) -> Iterator[Line]: + yield from self.visit_default(node) + node.children[1].prefix = "" + + def visit_paramspec(self, node: Node) -> Iterator[Line]: + yield from self.visit_default(node) + node.children[1].prefix = "" + def visit_dictsetmaker(self, node: Node) -> Iterator[Line]: if Preview.wrap_long_dict_values_in_parens in self.mode: for i, child in enumerate(node.children): @@ -255,7 +282,9 @@ class LineGenerator(Visitor[Line]): def visit_suite(self, node: Node) -> Iterator[Line]: """Visit a suite.""" - if self.mode.is_pyi and is_stub_suite(node): + if ( + self.mode.is_pyi or Preview.dummy_implementations in self.mode + ) and is_stub_suite(node): yield from self.visit(node.children[2]) else: yield from self.visit_default(node) @@ -270,7 +299,9 @@ class LineGenerator(Visitor[Line]): is_suite_like = node.parent and node.parent.type in STATEMENT if is_suite_like: - if self.mode.is_pyi and is_stub_body(node): + if ( + self.mode.is_pyi or Preview.dummy_implementations in self.mode + ) and is_stub_body(node): yield from self.visit_default(node) else: yield from self.line(+1) @@ -279,7 +310,7 @@ class LineGenerator(Visitor[Line]): else: if ( - not self.mode.is_pyi + not (self.mode.is_pyi or Preview.dummy_implementations in self.mode) or not node.parent or not is_stub_suite(node.parent) ): @@ -300,8 +331,11 @@ class LineGenerator(Visitor[Line]): break internal_stmt = next(children) - for child in internal_stmt.children: - yield from self.visit(child) + if Preview.improved_async_statements_handling in self.mode: + yield from self.visit(internal_stmt) + else: + for child in internal_stmt.children: + yield from self.visit(child) def visit_decorators(self, node: Node) -> Iterator[Line]: """Visit decorators.""" @@ -364,6 +398,24 @@ class LineGenerator(Visitor[Line]): node.insert_child(index, Node(syms.atom, [lpar, operand, rpar])) yield from self.visit_default(node) + def visit_tname(self, node: Node) -> Iterator[Line]: + """ + Add potential parentheses around types in function parameter lists to be made + into real parentheses in case the type hint is too long to fit on a line + Examples: + def foo(a: int, b: float = 7): ... + + -> + + def foo(a: (int), b: (float) = 7): ... + """ + if Preview.parenthesize_long_type_hints in self.mode: + assert len(node.children) == 3 + if maybe_make_parens_invisible_in_atom(node.children[2], parent=node): + wrap_in_parentheses(node, node.children[2], visible=False) + + yield from self.visit_default(node) + def visit_STRING(self, leaf: Leaf) -> Iterator[Line]: if Preview.hex_codes_in_unicode_sequences in self.mode: normalize_unicode_escape_sequences(leaf) @@ -465,7 +517,14 @@ class LineGenerator(Visitor[Line]): self.visit_except_clause = partial(v, keywords={"except"}, parens={"except"}) self.visit_with_stmt = partial(v, keywords={"with"}, parens={"with"}) self.visit_classdef = partial(v, keywords={"class"}, parens=Ø) - self.visit_expr_stmt = partial(v, keywords=Ø, parens=ASSIGNMENTS) + + # When this is moved out of preview, add ":" directly to ASSIGNMENTS in nodes.py + if Preview.parenthesize_long_type_hints in self.mode: + assignments = ASSIGNMENTS | {":"} + else: + assignments = ASSIGNMENTS + self.visit_expr_stmt = partial(v, keywords=Ø, parens=assignments) + self.visit_return_stmt = partial(v, keywords={"return"}, parens={"return"}) self.visit_import_from = partial(v, keywords=Ø, parens={"import"}) self.visit_del_stmt = partial(v, keywords=Ø, parens={"del"}) @@ -505,7 +564,7 @@ def transform_line( and not line.should_split_rhs and not line.magic_trailing_comma and ( - is_line_short_enough(line, line_length=mode.line_length, line_str=line_str) + is_line_short_enough(line, mode=mode, line_str=line_str) or line.contains_unsplittable_type_ignore() ) and not (line.inside_brackets and line.contains_standalone_comments()) @@ -515,7 +574,7 @@ def transform_line( transformers = [string_merge, string_paren_strip] else: transformers = [] - elif line.is_def: + elif line.is_def and not should_split_funcdef_with_rhs(line, mode): transformers = [left_hand_split] else: @@ -529,14 +588,12 @@ def transform_line( bracket pair instead. """ for omit in generate_trailers_to_omit(line, mode.line_length): - lines = list( - right_hand_split(line, mode.line_length, features, omit=omit) - ) + lines = list(right_hand_split(line, mode, features, omit=omit)) # Note: this check is only able to figure out if the first line of the # *current* transformation fits in the line length. This is true only # for simple cases. All others require running more transforms via # `transform_line()`. This check doesn't know if those would succeed. - if is_line_short_enough(lines[0], line_length=mode.line_length): + if is_line_short_enough(lines[0], mode=mode): yield from lines return @@ -544,9 +601,7 @@ def transform_line( # This mostly happens to multiline strings that are by definition # reported as not fitting a single line, as well as lines that contain # trailing commas (those have to be exploded). - yield from right_hand_split( - line, line_length=mode.line_length, features=features - ) + yield from right_hand_split(line, mode, features=features) # HACK: nested functions (like _rhs) compiled by mypyc don't retain their # __name__ attribute which is needed in `run_transformer` further down. @@ -598,6 +653,40 @@ def transform_line( yield line +def should_split_funcdef_with_rhs(line: Line, mode: Mode) -> bool: + """If a funcdef has a magic trailing comma in the return type, then we should first + split the line with rhs to respect the comma. + """ + if Preview.respect_magic_trailing_comma_in_return_type not in mode: + return False + + return_type_leaves: List[Leaf] = [] + in_return_type = False + + for leaf in line.leaves: + if leaf.type == token.COLON: + in_return_type = False + if in_return_type: + return_type_leaves.append(leaf) + if leaf.type == token.RARROW: + in_return_type = True + + # using `bracket_split_build_line` will mess with whitespace, so we duplicate a + # couple lines from it. + result = Line(mode=line.mode, depth=line.depth) + leaves_to_track = get_leaves_inside_matching_brackets(return_type_leaves) + for leaf in return_type_leaves: + result.append( + leaf, + preformatted=True, + track_bracket=id(leaf) in leaves_to_track, + ) + + # we could also return true if the line is too long, and the return type is longer + # than the param list. Or if `should_split_rhs` returns True. + return result.magic_trailing_comma is not None + + class _BracketSplitComponent(Enum): head = auto() body = auto() @@ -651,20 +740,9 @@ def left_hand_split( yield result -@dataclass -class _RHSResult: - """Intermediate split result from a right hand split.""" - - head: Line - body: Line - tail: Line - opening_bracket: Leaf - closing_bracket: Leaf - - def right_hand_split( line: Line, - line_length: int, + mode: Mode, features: Collection[Feature] = (), omit: Collection[LeafID] = (), ) -> Iterator[Line]: @@ -678,14 +756,14 @@ def right_hand_split( """ rhs_result = _first_right_hand_split(line, omit=omit) yield from _maybe_split_omitting_optional_parens( - rhs_result, line, line_length, features=features, omit=omit + rhs_result, line, mode, features=features, omit=omit ) def _first_right_hand_split( line: Line, omit: Collection[LeafID] = (), -) -> _RHSResult: +) -> RHSResult: """Split the line into head, body, tail starting with the last bracket pair. Note: this function should not have side effects. It's relied upon by @@ -727,13 +805,13 @@ def _first_right_hand_split( tail_leaves, line, opening_bracket, component=_BracketSplitComponent.tail ) bracket_split_succeeded_or_raise(head, body, tail) - return _RHSResult(head, body, tail, opening_bracket, closing_bracket) + return RHSResult(head, body, tail, opening_bracket, closing_bracket) def _maybe_split_omitting_optional_parens( - rhs: _RHSResult, + rhs: RHSResult, line: Line, - line_length: int, + mode: Mode, features: Collection[Feature] = (), omit: Collection[LeafID] = (), ) -> Iterator[Line]: @@ -751,11 +829,11 @@ def _maybe_split_omitting_optional_parens( # there are no standalone comments in the body and not rhs.body.contains_standalone_comments(0) # and we can actually remove the parens - and can_omit_invisible_parens(rhs.body, line_length) + and can_omit_invisible_parens(rhs, mode.line_length) ): omit = {id(rhs.closing_bracket), *omit} try: - # The _RHSResult Omitting Optional Parens. + # The RHSResult Omitting Optional Parens. rhs_oop = _first_right_hand_split(line, omit=omit) if not ( Preview.prefer_splitting_right_hand_side_of_assignments in line.mode @@ -766,23 +844,24 @@ def _maybe_split_omitting_optional_parens( and any(leaf.type in BRACKETS for leaf in rhs.head.leaves[:-1]) # the left side of assignment is short enough (the -1 is for the ending # optional paren) - and is_line_short_enough(rhs.head, line_length=line_length - 1) + and is_line_short_enough( + rhs.head, mode=replace(mode, line_length=mode.line_length - 1) + ) # the left side of assignment won't explode further because of magic # trailing comma and rhs.head.magic_trailing_comma is None # the split by omitting optional parens isn't preferred by some other # reason - and not _prefer_split_rhs_oop(rhs_oop, line_length=line_length) + and not _prefer_split_rhs_oop(rhs_oop, mode) ): yield from _maybe_split_omitting_optional_parens( - rhs_oop, line, line_length, features=features, omit=omit + rhs_oop, line, mode, features=features, omit=omit ) return except CannotSplit as e: if not ( - can_be_split(rhs.body) - or is_line_short_enough(rhs.body, line_length=line_length) + can_be_split(rhs.body) or is_line_short_enough(rhs.body, mode=mode) ): raise CannotSplit( "Splitting failed, body is still too long and can't be split." @@ -806,7 +885,7 @@ def _maybe_split_omitting_optional_parens( yield result -def _prefer_split_rhs_oop(rhs_oop: _RHSResult, line_length: int) -> bool: +def _prefer_split_rhs_oop(rhs_oop: RHSResult, mode: Mode) -> bool: """ Returns whether we should prefer the result from a split omitting optional parens. """ @@ -826,7 +905,7 @@ def _prefer_split_rhs_oop(rhs_oop: _RHSResult, line_length: int) -> bool: # the first line still contains the `=`) any(leaf.type == token.EQUAL for leaf in rhs_oop.head.leaves) # the first line is short enough - and is_line_short_enough(rhs_oop.head, line_length=line_length) + and is_line_short_enough(rhs_oop.head, mode=mode) ) # contains unsplittable type ignore or rhs_oop.head.contains_unsplittable_type_ignore() @@ -904,6 +983,13 @@ def bracket_split_build_line( ) if isinstance(node, Node) and isinstance(node.prev_sibling, Leaf) ) + # Except the false negatives above for PEP 604 unions where we + # can't add the comma. + and not ( + leaves[0].parent + and leaves[0].parent.next_sibling + and leaves[0].parent.next_sibling.type == token.VBAR + ) ) if original.is_import or no_commas: @@ -1342,7 +1428,7 @@ def maybe_make_parens_invisible_in_atom( Returns whether the node should itself be wrapped in invisible parentheses. """ if ( - node.type != syms.atom + node.type not in (syms.atom, syms.expr) or is_empty_tuple(node) or is_one_tuple(node) or (is_yield(node) and parent.type != syms.expr_stmt) @@ -1366,6 +1452,7 @@ def maybe_make_parens_invisible_in_atom( syms.except_clause, syms.funcdef, syms.with_stmt, + syms.tname, # these ones aren't useful to end users, but they do please fuzzers syms.for_stmt, syms.del_stmt, @@ -1378,8 +1465,13 @@ def maybe_make_parens_invisible_in_atom( if is_lpar_token(first) and is_rpar_token(last): middle = node.children[1] # make parentheses invisible - first.value = "" - last.value = "" + if ( + # If the prefix of `middle` includes a type comment with + # ignore annotation, then we do not remove the parentheses + not is_type_ignore_comment_string(middle.prefix.strip()) + ): + first.value = "" + last.value = "" maybe_make_parens_invisible_in_atom( middle, parent=parent, @@ -1525,7 +1617,7 @@ def run_transformer( or line.contains_multiline_strings() or result[0].contains_uncollapsable_type_comments() or result[0].contains_unsplittable_type_ignore() - or is_line_short_enough(result[0], line_length=mode.line_length) + or is_line_short_enough(result[0], mode=mode) # If any leaves have no parents (which _can_ occur since # `transform(line)` potentially destroys the line's underlying node # structure), then we can't proceed. Doing so would cause the below @@ -1540,8 +1632,6 @@ def run_transformer( second_opinion = run_transformer( line_copy, transform, mode, features_fop, line_str=line_str ) - if all( - is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion - ): + if all(is_line_short_enough(ln, mode=mode) for ln in second_opinion): result = second_opinion return result