X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/9c8464ca7ddd48d1c19112d895ae12d783f01563..1b028cc9d99c2c2e82f9b727742539173a92a373:/src/black/linegen.py?ds=sidebyside

diff --git a/src/black/linegen.py b/src/black/linegen.py
index f7d3655..5ef3bbd 100644
--- a/src/black/linegen.py
+++ b/src/black/linegen.py
@@ -2,7 +2,7 @@
 Generating lines of code.
 """
 import sys
-from dataclasses import dataclass
+from dataclasses import replace
 from enum import Enum, auto
 from functools import partial, wraps
 from typing import Collection, Iterator, List, Optional, Set, Union, cast
@@ -16,6 +16,7 @@ from black.brackets import (
 from black.comments import FMT_OFF, generate_comments, list_comments
 from black.lines import (
     Line,
+    RHSResult,
     append_leaves,
     can_be_split,
     can_omit_invisible_parens,
@@ -35,6 +36,7 @@ from black.nodes import (
     Visitor,
     ensure_visible,
     is_arith_like,
+    is_async_stmt_or_funcdef,
     is_atom_with_invisible_parens,
     is_docstring,
     is_empty_tuple,
@@ -47,6 +49,7 @@ from black.nodes import (
     is_stub_body,
     is_stub_suite,
     is_tuple_containing_walrus,
+    is_type_ignore_comment_string,
     is_vararg,
     is_walrus_assignment,
     is_yield,
@@ -109,6 +112,17 @@ class LineGenerator(Visitor[Line]):
             self.current_line.depth += indent
             return  # Line is empty, don't emit. Creating a new one unnecessary.
 
+        if (
+            Preview.improved_async_statements_handling in self.mode
+            and len(self.current_line.leaves) == 1
+            and is_async_stmt_or_funcdef(self.current_line.leaves[0])
+        ):
+            # Special case for async def/for/with statements. `visit_async_stmt`
+            # adds an `ASYNC` leaf then visits the child def/for/with statement
+            # nodes. Line yields from those nodes shouldn't treat the former
+            # `ASYNC` leaf as a complete line.
+            return
+
         complete_line = self.current_line
         self.current_line = Line(mode=self.mode, depth=complete_line.depth + indent)
         yield complete_line
@@ -202,6 +216,18 @@ class LineGenerator(Visitor[Line]):
 
             yield from self.visit(child)
 
+    def visit_typeparams(self, node: Node) -> Iterator[Line]:
+        yield from self.visit_default(node)
+        node.children[0].prefix = ""
+
+    def visit_typevartuple(self, node: Node) -> Iterator[Line]:
+        yield from self.visit_default(node)
+        node.children[1].prefix = ""
+
+    def visit_paramspec(self, node: Node) -> Iterator[Line]:
+        yield from self.visit_default(node)
+        node.children[1].prefix = ""
+
     def visit_dictsetmaker(self, node: Node) -> Iterator[Line]:
         if Preview.wrap_long_dict_values_in_parens in self.mode:
             for i, child in enumerate(node.children):
@@ -300,8 +326,11 @@ class LineGenerator(Visitor[Line]):
                 break
 
         internal_stmt = next(children)
-        for child in internal_stmt.children:
-            yield from self.visit(child)
+        if Preview.improved_async_statements_handling in self.mode:
+            yield from self.visit(internal_stmt)
+        else:
+            for child in internal_stmt.children:
+                yield from self.visit(child)
 
     def visit_decorators(self, node: Node) -> Iterator[Line]:
         """Visit decorators."""
@@ -505,7 +534,7 @@ def transform_line(
         and not line.should_split_rhs
         and not line.magic_trailing_comma
         and (
-            is_line_short_enough(line, line_length=mode.line_length, line_str=line_str)
+            is_line_short_enough(line, mode=mode, line_str=line_str)
             or line.contains_unsplittable_type_ignore()
         )
         and not (line.inside_brackets and line.contains_standalone_comments())
@@ -529,14 +558,12 @@ def transform_line(
             bracket pair instead.
             """
             for omit in generate_trailers_to_omit(line, mode.line_length):
-                lines = list(
-                    right_hand_split(line, mode.line_length, features, omit=omit)
-                )
+                lines = list(right_hand_split(line, mode, features, omit=omit))
                 # Note: this check is only able to figure out if the first line of the
                 # *current* transformation fits in the line length.  This is true only
                 # for simple cases.  All others require running more transforms via
                 # `transform_line()`.  This check doesn't know if those would succeed.
-                if is_line_short_enough(lines[0], line_length=mode.line_length):
+                if is_line_short_enough(lines[0], mode=mode):
                     yield from lines
                     return
 
@@ -544,9 +571,7 @@ def transform_line(
             # This mostly happens to multiline strings that are by definition
             # reported as not fitting a single line, as well as lines that contain
             # trailing commas (those have to be exploded).
-            yield from right_hand_split(
-                line, line_length=mode.line_length, features=features
-            )
+            yield from right_hand_split(line, mode, features=features)
 
         # HACK: nested functions (like _rhs) compiled by mypyc don't retain their
         # __name__ attribute which is needed in `run_transformer` further down.
@@ -651,20 +676,9 @@ def left_hand_split(
             yield result
 
 
-@dataclass
-class _RHSResult:
-    """Intermediate split result from a right hand split."""
-
-    head: Line
-    body: Line
-    tail: Line
-    opening_bracket: Leaf
-    closing_bracket: Leaf
-
-
 def right_hand_split(
     line: Line,
-    line_length: int,
+    mode: Mode,
     features: Collection[Feature] = (),
     omit: Collection[LeafID] = (),
 ) -> Iterator[Line]:
@@ -678,14 +692,14 @@ def right_hand_split(
     """
     rhs_result = _first_right_hand_split(line, omit=omit)
     yield from _maybe_split_omitting_optional_parens(
-        rhs_result, line, line_length, features=features, omit=omit
+        rhs_result, line, mode, features=features, omit=omit
     )
 
 
 def _first_right_hand_split(
     line: Line,
     omit: Collection[LeafID] = (),
-) -> _RHSResult:
+) -> RHSResult:
     """Split the line into head, body, tail starting with the last bracket pair.
 
     Note: this function should not have side effects. It's relied upon by
@@ -727,13 +741,13 @@ def _first_right_hand_split(
         tail_leaves, line, opening_bracket, component=_BracketSplitComponent.tail
     )
     bracket_split_succeeded_or_raise(head, body, tail)
-    return _RHSResult(head, body, tail, opening_bracket, closing_bracket)
+    return RHSResult(head, body, tail, opening_bracket, closing_bracket)
 
 
 def _maybe_split_omitting_optional_parens(
-    rhs: _RHSResult,
+    rhs: RHSResult,
     line: Line,
-    line_length: int,
+    mode: Mode,
     features: Collection[Feature] = (),
     omit: Collection[LeafID] = (),
 ) -> Iterator[Line]:
@@ -751,11 +765,11 @@ def _maybe_split_omitting_optional_parens(
         # there are no standalone comments in the body
         and not rhs.body.contains_standalone_comments(0)
         # and we can actually remove the parens
-        and can_omit_invisible_parens(rhs.body, line_length)
+        and can_omit_invisible_parens(rhs, mode.line_length)
     ):
         omit = {id(rhs.closing_bracket), *omit}
         try:
-            # The _RHSResult Omitting Optional Parens.
+            # The RHSResult Omitting Optional Parens.
             rhs_oop = _first_right_hand_split(line, omit=omit)
             if not (
                 Preview.prefer_splitting_right_hand_side_of_assignments in line.mode
@@ -766,23 +780,24 @@ def _maybe_split_omitting_optional_parens(
                 and any(leaf.type in BRACKETS for leaf in rhs.head.leaves[:-1])
                 # the left side of assignment is short enough (the -1 is for the ending
                 # optional paren)
-                and is_line_short_enough(rhs.head, line_length=line_length - 1)
+                and is_line_short_enough(
+                    rhs.head, mode=replace(mode, line_length=mode.line_length - 1)
+                )
                 # the left side of assignment won't explode further because of magic
                 # trailing comma
                 and rhs.head.magic_trailing_comma is None
                 # the split by omitting optional parens isn't preferred by some other
                 # reason
-                and not _prefer_split_rhs_oop(rhs_oop, line_length=line_length)
+                and not _prefer_split_rhs_oop(rhs_oop, mode)
             ):
                 yield from _maybe_split_omitting_optional_parens(
-                    rhs_oop, line, line_length, features=features, omit=omit
+                    rhs_oop, line, mode, features=features, omit=omit
                 )
                 return
 
         except CannotSplit as e:
             if not (
-                can_be_split(rhs.body)
-                or is_line_short_enough(rhs.body, line_length=line_length)
+                can_be_split(rhs.body) or is_line_short_enough(rhs.body, mode=mode)
             ):
                 raise CannotSplit(
                     "Splitting failed, body is still too long and can't be split."
@@ -806,7 +821,7 @@ def _maybe_split_omitting_optional_parens(
             yield result
 
 
-def _prefer_split_rhs_oop(rhs_oop: _RHSResult, line_length: int) -> bool:
+def _prefer_split_rhs_oop(rhs_oop: RHSResult, mode: Mode) -> bool:
     """
     Returns whether we should prefer the result from a split omitting optional parens.
     """
@@ -826,7 +841,7 @@ def _prefer_split_rhs_oop(rhs_oop: _RHSResult, line_length: int) -> bool:
             # the first line still contains the `=`)
             any(leaf.type == token.EQUAL for leaf in rhs_oop.head.leaves)
             # the first line is short enough
-            and is_line_short_enough(rhs_oop.head, line_length=line_length)
+            and is_line_short_enough(rhs_oop.head, mode=mode)
         )
         # contains unsplittable type ignore
         or rhs_oop.head.contains_unsplittable_type_ignore()
@@ -904,6 +919,13 @@ def bracket_split_build_line(
                     )
                     if isinstance(node, Node) and isinstance(node.prev_sibling, Leaf)
                 )
+                # Except the false negatives above for PEP 604 unions where we
+                # can't add the comma.
+                and not (
+                    leaves[0].parent
+                    and leaves[0].parent.next_sibling
+                    and leaves[0].parent.next_sibling.type == token.VBAR
+                )
             )
 
             if original.is_import or no_commas:
@@ -1378,8 +1400,13 @@ def maybe_make_parens_invisible_in_atom(
     if is_lpar_token(first) and is_rpar_token(last):
         middle = node.children[1]
         # make parentheses invisible
-        first.value = ""
-        last.value = ""
+        if (
+            # If the prefix of `middle` includes a type comment with
+            # ignore annotation, then we do not remove the parentheses
+            not is_type_ignore_comment_string(middle.prefix.strip())
+        ):
+            first.value = ""
+            last.value = ""
         maybe_make_parens_invisible_in_atom(
             middle,
             parent=parent,
@@ -1525,7 +1552,7 @@ def run_transformer(
         or line.contains_multiline_strings()
         or result[0].contains_uncollapsable_type_comments()
         or result[0].contains_unsplittable_type_ignore()
-        or is_line_short_enough(result[0], line_length=mode.line_length)
+        or is_line_short_enough(result[0], mode=mode)
         # If any leaves have no parents (which _can_ occur since
         # `transform(line)` potentially destroys the line's underlying node
         # structure), then we can't proceed. Doing so would cause the below
@@ -1540,8 +1567,6 @@ def run_transformer(
     second_opinion = run_transformer(
         line_copy, transform, mode, features_fop, line_str=line_str
     )
-    if all(
-        is_line_short_enough(ln, line_length=mode.line_length) for ln in second_opinion
-    ):
+    if all(is_line_short_enough(ln, mode=mode) for ln in second_opinion):
         result = second_opinion
     return result