X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/00e7e12a3a412ea386806d5d4eeaed345e912940..ffaaf4838228c922b586a87f717ed402031fcc0a:/src/black/trans.py?ds=sidebyside

diff --git a/src/black/trans.py b/src/black/trans.py
index ca620f6..8893ab0 100644
--- a/src/black/trans.py
+++ b/src/black/trans.py
@@ -1,13 +1,15 @@
 """
 String transformers that can split and merge strings.
 """
+import re
+import sys
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
-import regex as re
 from typing import (
     Any,
     Callable,
+    ClassVar,
     Collection,
     Dict,
     Iterable,
@@ -15,25 +17,43 @@ from typing import (
     List,
     Optional,
     Sequence,
+    Set,
     Tuple,
     TypeVar,
     Union,
 )
 
-from black.rusty import Result, Ok, Err
+if sys.version_info < (3, 8):
+    from typing_extensions import Final, Literal
+else:
+    from typing import Literal, Final
+
+from mypy_extensions import trait
 
-from black.mode import Feature
-from black.nodes import syms, replace_child, parent_type
-from black.nodes import is_empty_par, is_empty_lpar, is_empty_rpar
-from black.nodes import OPENING_BRACKETS, CLOSING_BRACKETS, STANDALONE_COMMENT
-from black.lines import Line, append_leaves
 from black.brackets import BracketMatchError
 from black.comments import contains_pragma_comment
-from black.strings import has_triple_quotes, get_string_prefix, assert_is_leaf_string
-from black.strings import normalize_string_quotes
-
-from blib2to3.pytree import Leaf, Node
+from black.lines import Line, append_leaves
+from black.mode import Feature
+from black.nodes import (
+    CLOSING_BRACKETS,
+    OPENING_BRACKETS,
+    STANDALONE_COMMENT,
+    is_empty_lpar,
+    is_empty_par,
+    is_empty_rpar,
+    parent_type,
+    replace_child,
+    syms,
+)
+from black.rusty import Err, Ok, Result
+from black.strings import (
+    assert_is_leaf_string,
+    get_string_prefix,
+    has_triple_quotes,
+    normalize_string_quotes,
+)
 from blib2to3.pgen2 import token
+from blib2to3.pytree import Leaf, Node
 
 
 class CannotTransform(Exception):
@@ -61,7 +81,84 @@ def TErr(err_msg: str) -> Err[CannotTransform]:
     return Err(cant_transform)
 
 
-@dataclass  # type: ignore
+def hug_power_op(line: Line, features: Collection[Feature]) -> Iterator[Line]:
+    """A transformer that drops the spaces around `**` if both operands are simple."""
+
+    # Performance optimization to avoid unnecessary Leaf clones and other ops.
+    for leaf in line.leaves:
+        if leaf.type == token.DOUBLESTAR:
+            break
+    else:
+        raise CannotTransform("No doublestar token was found in the line.")
+
+    def is_simple_lookup(index: int, step: Literal[1, -1]) -> bool:
+        # Brackets and parentheses indicate calls, subscripts, etc. ...
+        # basically stuff that doesn't count as "simple". Only a NAME lookup
+        # or dotted lookup (eg. NAME.NAME) is OK.
+        if step == -1:
+            disallowed = {token.RPAR, token.RSQB}
+        else:
+            disallowed = {token.LPAR, token.LSQB}
+
+        while 0 <= index < len(line.leaves):
+            current = line.leaves[index]
+            if current.type in disallowed:
+                return False
+            if current.type not in {token.NAME, token.DOT} or current.value == "for":
+                # If the current token isn't disallowed, we'll assume this is simple as
+                # only the disallowed tokens are semantically attached to this lookup
+                # expression we're checking. Also, stop early if we hit the 'for' bit
+                # of a comprehension.
+                return True
+
+            index += step
+
+        return True
+
+    def is_simple_operand(index: int, kind: Literal["base", "exponent"]) -> bool:
+        # An operand is considered "simple" if it's a NAME, a numeric CONSTANT, a
+        # simple lookup (see above), with or without a preceding unary operator.
+        start = line.leaves[index]
+        if start.type in {token.NAME, token.NUMBER}:
+            return is_simple_lookup(index, step=(1 if kind == "exponent" else -1))
+
+        if start.type in {token.PLUS, token.MINUS, token.TILDE}:
+            if line.leaves[index + 1].type in {token.NAME, token.NUMBER}:
+                # step is always one: an operand with a preceding unary op (in
+                # practice only an exponent, e.g. x**-y) is checked for simplicity
+                # from the token after the operator, as in the NAME/NUMBER branch.
+                return is_simple_lookup(index + 1, step=1)
+
+        return False
+
+    new_line = line.clone()
+    should_hug = False
+    for idx, leaf in enumerate(line.leaves):
+        new_leaf = leaf.clone()
+        if should_hug:
+            new_leaf.prefix = ""
+            should_hug = False
+
+        should_hug = (
+            (0 < idx < len(line.leaves) - 1)
+            and leaf.type == token.DOUBLESTAR
+            and is_simple_operand(idx - 1, kind="base")
+            and line.leaves[idx - 1].value != "lambda"
+            and is_simple_operand(idx + 1, kind="exponent")
+        )
+        if should_hug:
+            new_leaf.prefix = ""
+
+        # We have to be careful to make a new line properly:
+        # - bracket related metadata must be maintained (handled by Line.append)
+        # - comments need to be copied over, updating the leaf IDs they're attached to
+        new_line.append(new_leaf, preformatted=True)
+        for comment_leaf in line.comments_after(leaf):
+            new_line.append(comment_leaf, preformatted=True)
+
+    yield new_line
+
+
 class StringTransformer(ABC):
     """
     An implementation of the Transformer protocol that relies on its
@@ -89,9 +186,13 @@ class StringTransformer(ABC):
         as much as possible.
     """
 
-    line_length: int
-    normalize_strings: bool
-    __name__ = "StringTransformer"
+    __name__: Final = "StringTransformer"
+
+    # Ideally this would be a dataclass, but unfortunately mypyc breaks when used with
+    # `abc.ABC`.
+    def __init__(self, line_length: int, normalize_strings: bool) -> None:
+        self.line_length = line_length
+        self.normalize_strings = normalize_strings
 
     @abstractmethod
     def do_match(self, line: Line) -> TMatchResult:
@@ -183,6 +284,7 @@ class CustomSplit:
     break_idx: int
 
 
+@trait
 class CustomSplitMapMixin:
     """
     This mixin class is used to map merged strings to a sequence of
@@ -190,8 +292,10 @@ class CustomSplitMapMixin:
     the resultant substrings go over the configured max line length.
     """
 
-    _Key = Tuple[StringID, str]
-    _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
+    _Key: ClassVar = Tuple[StringID, str]
+    _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict(
+        tuple
+    )
 
     @staticmethod
     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
@@ -242,7 +346,7 @@ class CustomSplitMapMixin:
         return key in self._CUSTOM_SPLIT_MAP
 
 
-class StringMerger(CustomSplitMapMixin, StringTransformer):
+class StringMerger(StringTransformer, CustomSplitMapMixin):
     """StringTransformer that merges strings together.
 
     Requirements:
@@ -267,7 +371,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer):
 
         is_valid_index = is_valid_index_factory(LL)
 
-        for (i, leaf) in enumerate(LL):
+        for i, leaf in enumerate(LL):
             if (
                 leaf.type == token.STRING
                 and is_valid_index(i + 1)
@@ -437,7 +541,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer):
             # with 'f'...
             if "f" in prefix and "f" not in next_prefix:
                 # Then we must escape any braces contained in this substring.
-                SS = re.subf(r"(\{|\})", "{1}{1}", SS)
+                SS = re.sub(r"(\{|\})", r"\1\1", SS)
 
             NSS = make_naked(SS, next_prefix)
 
@@ -449,6 +553,9 @@ class StringMerger(CustomSplitMapMixin, StringTransformer):
 
             next_str_idx += 1
 
+        # Take a note on the index of the non-STRING leaf.
+        non_string_idx = next_str_idx
+
         S_leaf = Leaf(token.STRING, S)
         if self.normalize_strings:
             S_leaf.value = normalize_string_quotes(S_leaf.value)
@@ -468,11 +575,22 @@ class StringMerger(CustomSplitMapMixin, StringTransformer):
         string_leaf = Leaf(token.STRING, S_leaf.value.replace(BREAK_MARK, ""))
 
         if atom_node is not None:
-            replace_child(atom_node, string_leaf)
+            # If not all children of the atom node are merged (this can happen
+            # when there is a standalone comment in the middle) ...
+            if non_string_idx - string_idx < len(atom_node.children):
+                # We need to replace the old STRING leaves with the new string leaf.
+                first_child_idx = LL[string_idx].remove()
+                for idx in range(string_idx + 1, non_string_idx):
+                    LL[idx].remove()
+                if first_child_idx is not None:
+                    atom_node.insert_child(first_child_idx, string_leaf)
+            else:
+                # Else replace the atom node with the new string leaf.
+                replace_child(atom_node, string_leaf)
 
         # Build the final line ('new_line') that this method will later return.
         new_line = line.clone()
-        for (i, leaf) in enumerate(LL):
+        for i, leaf in enumerate(LL):
             if i == string_idx:
                 new_line.append(string_leaf)
 
@@ -593,7 +711,7 @@ class StringParenStripper(StringTransformer):
 
         is_valid_index = is_valid_index_factory(LL)
 
-        for (idx, leaf) in enumerate(LL):
+        for idx, leaf in enumerate(LL):
             # Should be a string...
             if leaf.type != token.STRING:
                 continue
@@ -738,7 +856,7 @@ class BaseStringSplitter(StringTransformer):
         * The target string is not a multiline (i.e. triple-quote) string.
     """
 
-    STRING_OPERATORS = [
+    STRING_OPERATORS: Final = [
         token.EQEQUAL,
         token.GREATER,
         token.GREATEREQUAL,
@@ -925,8 +1043,90 @@ class BaseStringSplitter(StringTransformer):
         max_string_length = self.line_length - offset
         return max_string_length
 
+    @staticmethod
+    def _prefer_paren_wrap_match(LL: List[Leaf]) -> Optional[int]:
+        """
+        Returns:
+            string_idx such that @LL[string_idx] is equal to our target (i.e.
+            matched) string, if this line matches the "prefer paren wrap" statement
+            requirements listed in the 'Requirements' section of the StringParenWrapper
+            class's docstring. (A match always yields string_idx == 0.)
+                OR
+            None, otherwise.
+        """
+        # The line must start with a string.
+        if LL[0].type != token.STRING:
+            return None
+
+        # If the string is surrounded by commas (or is the first/last child)...
+        prev_sibling = LL[0].prev_sibling
+        next_sibling = LL[0].next_sibling
+        if not prev_sibling and not next_sibling and parent_type(LL[0]) == syms.atom:
+            # If it's an atom string, we need to check the parent atom's siblings.
+            parent = LL[0].parent
+            assert parent is not None  # Implied by the atom check; for type checkers.
+            prev_sibling = parent.prev_sibling
+            next_sibling = parent.next_sibling
+        if (not prev_sibling or prev_sibling.type == token.COMMA) and (
+            not next_sibling or next_sibling.type == token.COMMA
+        ):
+            return 0
 
-class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
+        return None
+
+
+def iter_fexpr_spans(s: str) -> Iterator[Tuple[int, int]]:
+    """
+    Yields half-open spans (start inclusive, end exclusive) of the top-level
+    expressions in a given f-string; nested braces (e.g. in format specs) are
+    folded into their enclosing span. Assumes the input is a valid f-string,
+    but will not crash if the input string is invalid.
+    """
+    stack: List[int] = []  # indices of the currently open "{" braces
+    i = 0
+    while i < len(s):
+        if s[i] == "{":
+            # if we're in a string part of the f-string, ignore escaped curly braces
+            if not stack and i + 1 < len(s) and s[i + 1] == "{":
+                i += 2
+                continue
+            stack.append(i)
+            i += 1
+            continue
+
+        if s[i] == "}":
+            if not stack:
+                i += 1
+                continue
+            j = stack.pop()
+            # we've made it back out of the expression! yield the span
+            if not stack:
+                yield (j, i + 1)
+            i += 1
+            continue
+
+        # if we're in an expression part of the f-string, fast forward through strings
+        # backslashes aren't handled: they were illegal in f-string expressions pre-3.12
+        if stack:
+            delim = None
+            if s[i : i + 3] in ("'''", '"""'):
+                delim = s[i : i + 3]
+            elif s[i] in ("'", '"'):
+                delim = s[i]
+            if delim:
+                i += len(delim)
+                while i < len(s) and s[i : i + len(delim)] != delim:
+                    i += 1
+                i += len(delim)
+                continue
+        i += 1
+
+
+def fstring_contains_expr(s: str) -> bool:
+    return any(iter_fexpr_spans(s))  # each span is a 2-tuple, hence truthy
+
+
+class StringSplitter(BaseStringSplitter, CustomSplitMapMixin):
     """
     StringTransformer that splits "atom" strings (i.e. strings which exist on
     lines by themselves).
@@ -964,22 +1164,14 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         CustomSplit objects and add them to the custom split map.
     """
 
-    MIN_SUBSTR_SIZE = 6
-    # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
-    RE_FEXPR = r"""
-    (?<!\{) (?:\{\{)* \{ (?!\{)
-        (?:
-            [^\{\}]
-            | \{\{
-            | \}\}
-            | (?R)
-        )+
-    \}
-    """
+    MIN_SUBSTR_SIZE: Final = 6
 
     def do_splitter_match(self, line: Line) -> TMatchResult:
         LL = line.leaves
 
+        if self._prefer_paren_wrap_match(LL) is not None:
+            return TErr("Line needs to be wrapped in parens first.")
+
         is_valid_index = is_valid_index_factory(LL)
 
         idx = 0
@@ -1042,15 +1234,15 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         # contain any f-expressions, but ONLY if the original f-string
         # contains at least one f-expression. Otherwise, we will alter the AST
         # of the program.
-        drop_pointless_f_prefix = ("f" in prefix) and re.search(
-            self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
+        drop_pointless_f_prefix = ("f" in prefix) and fstring_contains_expr(
+            LL[string_idx].value
         )
 
         first_string_line = True
 
         string_op_leaves = self._get_string_operator_leaves(LL)
         string_op_leaves_length = (
-            sum([len(str(prefix_leaf)) for prefix_leaf in string_op_leaves]) + 1
+            sum(len(str(prefix_leaf)) for prefix_leaf in string_op_leaves) + 1
             if string_op_leaves
             else 0
         )
@@ -1243,6 +1435,59 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
             last_line.comments = line.comments.copy()
             yield Ok(last_line)
 
+    def _iter_nameescape_slices(self, string: str) -> Iterator[Tuple[Index, Index]]:
+        """
+        Yields:
+            All inclusive ranges of @string which, if @string were to be split
+            there, would split a \\N{...} expression (which is NOT allowed).
+        """
+        if {"r", "b"} & set(get_string_prefix(string).lower()):
+            return  # \N{...} is not an escape in raw or bytes literals.
+        # True iff the last char seen was an unescaped backslash.
+        previous_was_unescaped_backslash = False
+        it = iter(enumerate(string))
+        for idx, c in it:
+            if c == "\\":
+                previous_was_unescaped_backslash = not previous_was_unescaped_backslash
+                continue
+            if not previous_was_unescaped_backslash or c != "N":
+                previous_was_unescaped_backslash = False
+                continue
+            previous_was_unescaped_backslash = False
+
+            begin = idx - 1  # the position of backslash before \N{...}
+            for idx, c in it:
+                if c == "}":
+                    end = idx
+                    break
+            else:
+                # malformed nameescape expression?
+                # should have been detected by AST parsing earlier...
+                raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
+            yield begin, end
+
+    def _iter_fexpr_slices(self, string: str) -> Iterator[Tuple[Index, Index]]:
+        """
+        Yields:
+            All half-open ranges of @string which, if @string were to be split
+            there, would result in the splitting of an f-expression (which is
+            NOT allowed). Nothing is yielded unless @string has an "f" prefix.
+        """
+        if "f" not in get_string_prefix(string).lower():
+            return
+        yield from iter_fexpr_spans(string)
+
+    def _get_illegal_split_indices(self, string: str) -> Set[Index]:
+        illegal_indices: Set[Index] = set()  # split indices that are never allowed
+        iterators = [
+            self._iter_fexpr_slices(string),
+            self._iter_nameescape_slices(string),
+        ]
+        for it in iterators:  # fexpr spans are half-open; \N spans are inclusive.
+            for begin, end in it:  # NOTE(review): so a fexpr's `end` is banned too
+                illegal_indices.update(range(begin, end + 1))
+        return illegal_indices
+
     def _get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
         """
         This method contains the algorithm that StringSplitter uses to
@@ -1272,40 +1517,15 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         assert is_valid_index(max_break_idx)
         assert_is_leaf_string(string)
 
-        _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
+        _illegal_split_indices = self._get_illegal_split_indices(string)
 
-        def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
-            """
-            Yields:
-                All ranges of @string which, if @string were to be split there,
-                would result in the splitting of an f-expression (which is NOT
-                allowed).
-            """
-            nonlocal _fexpr_slices
-
-            if _fexpr_slices is None:
-                _fexpr_slices = []
-                for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
-                    _fexpr_slices.append(match.span())
-
-            yield from _fexpr_slices
-
-        is_fstring = "f" in get_string_prefix(string).lower()
-
-        def breaks_fstring_expression(i: Index) -> bool:
+        def breaks_unsplittable_expression(i: Index) -> bool:
             """
             Returns:
                 True iff returning @i would result in the splitting of an
-                f-expression (which is NOT allowed).
+                unsplittable expression (which is NOT allowed).
             """
-            if not is_fstring:
-                return False
-
-            for (start, end) in fexpr_slices():
-                if start <= i < end:
-                    return True
-
-            return False
+            return i in _illegal_split_indices
 
         def passes_all_checks(i: Index) -> bool:
             """
@@ -1329,7 +1549,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
                 is_space
                 and is_not_escaped
                 and is_big_enough
-                and not breaks_fstring_expression(i)
+                and not breaks_unsplittable_expression(i)
             )
 
         # First, we check all indices BELOW @max_break_idx.
@@ -1371,7 +1591,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         """
         assert_is_leaf_string(string)
 
-        if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
+        if "f" in prefix and not fstring_contains_expr(string):
             new_prefix = prefix.replace("f", "")
 
             temp = string[len(prefix) :]
@@ -1395,10 +1615,9 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         return string_op_leaves
 
 
-class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
+class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin):
     """
-    StringTransformer that splits non-"atom" strings (i.e. strings that do not
-    exist on lines by themselves).
+    StringTransformer that wraps strings in parens and then splits at the LPAR.
 
     Requirements:
         All of the requirements listed in BaseStringSplitter's docstring in
@@ -1418,6 +1637,10 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
             OR
         * The line is a dictionary key assignment where some valid key is being
         assigned the value of some string.
+            OR
+        * The line starts with an "atom" string that prefers to be wrapped in
+        parens. It's preferred to be wrapped when the string is surrounded by
+        commas (or is the first/last child).
 
     Transformations:
         The chosen string is wrapped in parentheses and then split at the LPAR.
@@ -1442,6 +1665,9 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
         changed such that it no longer needs to be given its own line,
         StringParenWrapper relies on StringParenStripper to clean up the
         parentheses it created.
+
+        For "atom" strings that prefers to be wrapped in parens, it requires
+        StringSplitter to hold the split until the string is wrapped in parens.
     """
 
     def do_splitter_match(self, line: Line) -> TMatchResult:
@@ -1458,6 +1684,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
             or self._assert_match(LL)
             or self._assign_match(LL)
             or self._dict_match(LL)
+            or self._prefer_paren_wrap_match(LL)
         )
 
         if string_idx is not None:
@@ -1547,7 +1774,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
             is_valid_index = is_valid_index_factory(LL)
 
-            for (i, leaf) in enumerate(LL):
+            for i, leaf in enumerate(LL):
                 # We MUST find a comma...
                 if leaf.type == token.COMMA:
                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
@@ -1585,7 +1812,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
         ):
             is_valid_index = is_valid_index_factory(LL)
 
-            for (i, leaf) in enumerate(LL):
+            for i, leaf in enumerate(LL):
                 # We MUST find either an '=' or '+=' symbol...
                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
@@ -1628,7 +1855,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
             is_valid_index = is_valid_index_factory(LL)
 
-            for (i, leaf) in enumerate(LL):
+            for i, leaf in enumerate(LL):
                 # We MUST find a colon...
                 if leaf.type == token.COLON:
                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
@@ -1780,20 +2007,20 @@ class StringParser:
         ```
     """
 
-    DEFAULT_TOKEN = -1
+    DEFAULT_TOKEN: Final = 20210605
 
     # String Parser States
-    START = 1
-    DOT = 2
-    NAME = 3
-    PERCENT = 4
-    SINGLE_FMT_ARG = 5
-    LPAR = 6
-    RPAR = 7
-    DONE = 8
+    START: Final = 1
+    DOT: Final = 2
+    NAME: Final = 3
+    PERCENT: Final = 4
+    SINGLE_FMT_ARG: Final = 5
+    LPAR: Final = 6
+    RPAR: Final = 7
+    DONE: Final = 8
 
     # Lookup Table for Next State
-    _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
+    _goto: Final[Dict[Tuple[ParserState, NodeType], ParserState]] = {
         # A string trailer may start with '.' OR '%'.
         (START, token.DOT): DOT,
         (START, token.PERCENT): PERCENT,