git.madduck.net Git - etc/vim.git/blobdiff - src/black/linegen.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Normalise string prefix order (#2297)
[etc/vim.git] / src / black / linegen.py
index fafaf1032ca55e002d49a5c3b7a8789f57f522e4..6008c773f943f0c22e2fec52eb6f5bfbcea29ed8 100644 (file)
@@ -5,12 +5,11 @@ from functools import partial, wraps
 import sys
 from typing import Collection, Iterator, List, Optional, Set, Union
 
 import sys
 from typing import Collection, Iterator, List, Optional, Set, Union
 
-from dataclasses import dataclass, field
-
 from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT
 from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS
 from black.nodes import Visitor, syms, first_child_is_arith, ensure_visible
 from black.nodes import is_docstring, is_empty_tuple, is_one_tuple, is_one_tuple_between
 from black.nodes import WHITESPACE, RARROW, STATEMENT, STANDALONE_COMMENT
 from black.nodes import ASSIGNMENTS, OPENING_BRACKETS, CLOSING_BRACKETS
 from black.nodes import Visitor, syms, first_child_is_arith, ensure_visible
 from black.nodes import is_docstring, is_empty_tuple, is_one_tuple, is_one_tuple_between
+from black.nodes import is_name_token, is_lpar_token, is_rpar_token
 from black.nodes import is_walrus_assignment, is_yield, is_vararg, is_multiline_string
 from black.nodes import is_stub_suite, is_stub_body, is_atom_with_invisible_parens
 from black.nodes import wrap_in_parentheses
 from black.nodes import is_walrus_assignment, is_yield, is_vararg, is_multiline_string
 from black.nodes import is_stub_suite, is_stub_body, is_atom_with_invisible_parens
 from black.nodes import wrap_in_parentheses
@@ -40,7 +39,8 @@ class CannotSplit(CannotTransform):
     """A readable split that fits the allotted line length is impossible."""
 
 
     """A readable split that fits the allotted line length is impossible."""
 
 
-@dataclass
+# This isn't a dataclass because @dataclass + Generic breaks mypyc.
+# See also https://github.com/mypyc/mypyc/issues/827.
 class LineGenerator(Visitor[Line]):
     """Generates reformatted Line objects.  Empty lines are not emitted.
 
 class LineGenerator(Visitor[Line]):
     """Generates reformatted Line objects.  Empty lines are not emitted.
 
@@ -48,9 +48,10 @@ class LineGenerator(Visitor[Line]):
     in ways that will no longer stringify to valid Python code on the tree.
     """
 
     in ways that will no longer stringify to valid Python code on the tree.
     """
 
-    mode: Mode
-    remove_u_prefix: bool = False
-    current_line: Line = field(init=False)
+    def __init__(self, mode: Mode) -> None:
+        self.mode = mode
+        self.current_line: Line
+        self.__post_init__()
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
 
     def line(self, indent: int = 0) -> Iterator[Line]:
         """Generate a line.
@@ -90,9 +91,7 @@ class LineGenerator(Visitor[Line]):
 
             normalize_prefix(node, inside_brackets=any_open_brackets)
             if self.mode.string_normalization and node.type == token.STRING:
 
             normalize_prefix(node, inside_brackets=any_open_brackets)
             if self.mode.string_normalization and node.type == token.STRING:
-                node.value = normalize_string_prefix(
-                    node.value, remove_u_prefix=self.remove_u_prefix
-                )
+                node.value = normalize_string_prefix(node.value)
                 node.value = normalize_string_quotes(node.value)
             if node.type == token.NUMBER:
                 normalize_numeric_literal(node)
                 node.value = normalize_string_quotes(node.value)
             if node.type == token.NUMBER:
                 normalize_numeric_literal(node)
@@ -126,7 +125,7 @@ class LineGenerator(Visitor[Line]):
         """Visit a statement.
 
         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
         """Visit a statement.
 
         This implementation is shared for `if`, `while`, `for`, `try`, `except`,
-        `def`, `with`, `class`, `assert` and assignments.
+        `def`, `with`, `class`, `assert`, and assignments.
 
         The relevant Python language `keywords` for a given statement will be
         NAME leaves within it. This methods puts those on a separate line.
 
         The relevant Python language `keywords` for a given statement will be
         NAME leaves within it. This methods puts those on a separate line.
@@ -136,11 +135,19 @@ class LineGenerator(Visitor[Line]):
         """
         normalize_invisible_parens(node, parens_after=parens)
         for child in node.children:
         """
         normalize_invisible_parens(node, parens_after=parens)
         for child in node.children:
-            if child.type == token.NAME and child.value in keywords:  # type: ignore
+            if is_name_token(child) and child.value in keywords:
                 yield from self.line()
 
             yield from self.visit(child)
 
                 yield from self.line()
 
             yield from self.visit(child)
 
+    def visit_match_case(self, node: Node) -> Iterator[Line]:
+        """Visit either a match or case statement."""
+        normalize_invisible_parens(node, parens_after=set())
+
+        yield from self.line()
+        for child in node.children:
+            yield from self.visit(child)
+
     def visit_suite(self, node: Node) -> Iterator[Line]:
         """Visit a suite."""
         if self.mode.is_pyi and is_stub_suite(node):
     def visit_suite(self, node: Node) -> Iterator[Line]:
         """Visit a suite."""
         if self.mode.is_pyi and is_stub_suite(node):
@@ -226,7 +233,7 @@ class LineGenerator(Visitor[Line]):
         if is_docstring(leaf) and "\\\n" not in leaf.value:
             # We're ignoring docstrings with backslash newline escapes because changing
             # indentation of those changes the AST representation of the code.
         if is_docstring(leaf) and "\\\n" not in leaf.value:
             # We're ignoring docstrings with backslash newline escapes because changing
             # indentation of those changes the AST representation of the code.
-            docstring = normalize_string_prefix(leaf.value, self.remove_u_prefix)
+            docstring = normalize_string_prefix(leaf.value)
             prefix = get_string_prefix(docstring)
             docstring = docstring[len(prefix) :]  # Remove the prefix
             quote_char = docstring[0]
             prefix = get_string_prefix(docstring)
             docstring = docstring[len(prefix) :]  # Remove the prefix
             quote_char = docstring[0]
@@ -292,6 +299,10 @@ class LineGenerator(Visitor[Line]):
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
         self.visit_async_funcdef = self.visit_async_stmt
         self.visit_decorated = self.visit_decorators
 
+        # PEP 634
+        self.visit_match_stmt = self.visit_match_case
+        self.visit_case_block = self.visit_match_case
+
 
 def transform_line(
     line: Line, mode: Mode, features: Collection[Feature] = ()
 
 def transform_line(
     line: Line, mode: Mode, features: Collection[Feature] = ()
@@ -335,7 +346,9 @@ def transform_line(
         transformers = [left_hand_split]
     else:
 
         transformers = [left_hand_split]
     else:
 
-        def rhs(line: Line, features: Collection[Feature]) -> Iterator[Line]:
+        def _rhs(
+            self: object, line: Line, features: Collection[Feature]
+        ) -> Iterator[Line]:
             """Wraps calls to `right_hand_split`.
 
             The calls increasingly `omit` right-hand trailers (bracket pairs with
             """Wraps calls to `right_hand_split`.
 
             The calls increasingly `omit` right-hand trailers (bracket pairs with
@@ -362,6 +375,12 @@ def transform_line(
                 line, line_length=mode.line_length, features=features
             )
 
                 line, line_length=mode.line_length, features=features
             )
 
+        # HACK: nested functions (like _rhs) compiled by mypyc don't retain their
+        # __name__ attribute which is needed in `run_transformer` further down.
+        # Unfortunately a nested class breaks mypyc too. So a class must be created
+        # via type ... https://github.com/mypyc/mypyc/issues/884
+        rhs = type("rhs", (), {"__call__": _rhs})()
+
         if mode.experimental_string_processing:
             if line.inside_brackets:
                 transformers = [
         if mode.experimental_string_processing:
             if line.inside_brackets:
                 transformers = [
@@ -503,14 +522,14 @@ def right_hand_split(
             yield from right_hand_split(line, line_length, features=features, omit=omit)
             return
 
             yield from right_hand_split(line, line_length, features=features, omit=omit)
             return
 
-        except CannotSplit:
+        except CannotSplit as e:
             if not (
                 can_be_split(body)
                 or is_line_short_enough(body, line_length=line_length)
             ):
                 raise CannotSplit(
                     "Splitting failed, body is still too long and can't be split."
             if not (
                 can_be_split(body)
                 or is_line_short_enough(body, line_length=line_length)
             ):
                 raise CannotSplit(
                     "Splitting failed, body is still too long and can't be split."
-                )
+                ) from e
 
             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
                 raise CannotSplit(
 
             elif head.contains_multiline_strings() or tail.contains_multiline_strings():
                 raise CannotSplit(
@@ -518,7 +537,7 @@ def right_hand_split(
                     " satisfy the splitting algorithm because the head or the tail"
                     " contains multiline strings which by definition never fit one"
                     " line."
                     " satisfy the splitting algorithm because the head or the tail"
                     " contains multiline strings which by definition never fit one"
                     " line."
-                )
+                ) from e
 
     ensure_visible(opening_bracket)
     ensure_visible(closing_bracket)
 
     ensure_visible(opening_bracket)
     ensure_visible(closing_bracket)
@@ -635,13 +654,13 @@ def delimiter_split(line: Line, features: Collection[Feature] = ()) -> Iterator[
     try:
         last_leaf = line.leaves[-1]
     except IndexError:
     try:
         last_leaf = line.leaves[-1]
     except IndexError:
-        raise CannotSplit("Line empty")
+        raise CannotSplit("Line empty") from None
 
     bt = line.bracket_tracker
     try:
         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
     except ValueError:
 
     bt = line.bracket_tracker
     try:
         delimiter_priority = bt.max_delimiter_priority(exclude={id(last_leaf)})
     except ValueError:
-        raise CannotSplit("No delimiters found")
+        raise CannotSplit("No delimiters found") from None
 
     if delimiter_priority == DOT_PRIORITY:
         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
 
     if delimiter_priority == DOT_PRIORITY:
         if bt.delimiter_count_with_priority(delimiter_priority) == 1:
@@ -792,10 +811,11 @@ def normalize_invisible_parens(node: Node, parens_after: Set[str]) -> None:
             elif node.type == syms.import_from:
                 # "import from" nodes store parentheses directly as part of
                 # the statement
             elif node.type == syms.import_from:
                 # "import from" nodes store parentheses directly as part of
                 # the statement
-                if child.type == token.LPAR:
+                if is_lpar_token(child):
+                    assert is_rpar_token(node.children[-1])
                     # make parentheses invisible
                     # make parentheses invisible
-                    child.value = ""  # type: ignore
-                    node.children[-1].value = ""  # type: ignore
+                    child.value = ""
+                    node.children[-1].value = ""
                 elif child.type != token.STAR:
                     # insert invisible parentheses
                     node.insert_child(index, Leaf(token.LPAR, ""))
                 elif child.type != token.STAR:
                     # insert invisible parentheses
                     node.insert_child(index, Leaf(token.LPAR, ""))
@@ -840,11 +860,11 @@ def maybe_make_parens_invisible_in_atom(node: LN, parent: LN) -> bool:
 
     first = node.children[0]
     last = node.children[-1]
 
     first = node.children[0]
     last = node.children[-1]
-    if first.type == token.LPAR and last.type == token.RPAR:
+    if is_lpar_token(first) and is_rpar_token(last):
         middle = node.children[1]
         # make parentheses invisible
         middle = node.children[1]
         # make parentheses invisible
-        first.value = ""  # type: ignore
-        last.value = ""  # type: ignore
+        first.value = ""
+        last.value = ""
         maybe_make_parens_invisible_in_atom(middle, parent=parent)
 
         if is_atom_with_invisible_parens(middle):
         maybe_make_parens_invisible_in_atom(middle, parent=parent)
 
         if is_atom_with_invisible_parens(middle):
@@ -976,7 +996,7 @@ def run_transformer(
         result.extend(transform_line(transformed_line, mode=mode, features=features))
 
     if (
         result.extend(transform_line(transformed_line, mode=mode, features=features))
 
     if (
-        transform.__name__ != "rhs"
+        transform.__class__.__name__ != "rhs"
         or not line.bracket_tracker.invisible
         or any(bracket.value for bracket in line.bracket_tracker.invisible)
         or line.contains_multiline_strings()
         or not line.bracket_tracker.invisible
         or any(bracket.value for bracket in line.bracket_tracker.invisible)
         or line.contains_multiline_strings()