X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/f2ea461e9e9fa5c47bb61fd72d512c748928badc..5543d1b55a2a485f2aaf32156ea97f4728264137:/src/black/trans.py?ds=sidebyside

diff --git a/src/black/trans.py b/src/black/trans.py
index 055f33c..cb41c1b 100644
--- a/src/black/trans.py
+++ b/src/black/trans.py
@@ -4,10 +4,11 @@ String transformers that can split and merge strings.
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
-import regex as re
+import re
 from typing import (
     Any,
     Callable,
+    ClassVar,
     Collection,
     Dict,
     Iterable,
@@ -15,17 +16,26 @@ from typing import (
     List,
     Optional,
     Sequence,
+    Set,
     Tuple,
     TypeVar,
     Union,
 )
+import sys
+
+if sys.version_info < (3, 8):
+    from typing_extensions import Final
+else:
+    from typing import Final
+
+from mypy_extensions import trait
 
 from black.rusty import Result, Ok, Err
 
 from black.mode import Feature
 from black.nodes import syms, replace_child, parent_type
 from black.nodes import is_empty_par, is_empty_lpar, is_empty_rpar
-from black.nodes import CLOSING_BRACKETS, STANDALONE_COMMENT
+from black.nodes import OPENING_BRACKETS, CLOSING_BRACKETS, STANDALONE_COMMENT
 from black.lines import Line, append_leaves
 from black.brackets import BracketMatchError
 from black.comments import contains_pragma_comment
@@ -61,7 +71,6 @@ def TErr(err_msg: str) -> Err[CannotTransform]:
     return Err(cant_transform)
 
 
-@dataclass  # type: ignore
 class StringTransformer(ABC):
     """
     An implementation of the Transformer protocol that relies on its
@@ -89,9 +98,13 @@ class StringTransformer(ABC):
         as much as possible.
     """
 
-    line_length: int
-    normalize_strings: bool
-    __name__ = "StringTransformer"
+    __name__: Final = "StringTransformer"
+
+    # Ideally this class would be a dataclass, but unfortunately mypyc breaks when
+    # `@dataclass` is combined with `abc.ABC`, so the fields are set by hand here.
+    def __init__(self, line_length: int, normalize_strings: bool) -> None:
+        self.line_length = line_length
+        self.normalize_strings = normalize_strings
 
     @abstractmethod
     def do_match(self, line: Line) -> TMatchResult:
@@ -183,6 +196,7 @@ class CustomSplit:
     break_idx: int
 
 
+@trait
 class CustomSplitMapMixin:
     """
     This mixin class is used to map merged strings to a sequence of
@@ -190,8 +204,10 @@ class CustomSplitMapMixin:
     the resultant substrings go over the configured max line length.
     """
 
-    _Key = Tuple[StringID, str]
-    _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple)
+    _Key: ClassVar = Tuple[StringID, str]
+    _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict(
+        tuple
+    )
 
     @staticmethod
     def _get_key(string: str) -> "CustomSplitMapMixin._Key":
@@ -242,7 +258,7 @@ class CustomSplitMapMixin:
         return key in self._CUSTOM_SPLIT_MAP
 
 
-class StringMerger(CustomSplitMapMixin, StringTransformer):
+class StringMerger(StringTransformer, CustomSplitMapMixin):
     """StringTransformer that merges strings together.
 
     Requirements:
@@ -411,7 +427,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer):
             and is_valid_index(next_str_idx)
             and LL[next_str_idx].type == token.STRING
         ):
-            prefix = get_string_prefix(LL[next_str_idx].value)
+            prefix = get_string_prefix(LL[next_str_idx].value).lower()
             next_str_idx += 1
 
         # The next loop merges the string group. The final string will be
@@ -431,13 +447,13 @@ class StringMerger(CustomSplitMapMixin, StringTransformer):
             num_of_strings += 1
 
             SS = LL[next_str_idx].value
-            next_prefix = get_string_prefix(SS)
+            next_prefix = get_string_prefix(SS).lower()
 
             # If this is an f-string group but this substring is not prefixed
             # with 'f'...
             if "f" in prefix and "f" not in next_prefix:
                 # Then we must escape any braces contained in this substring.
-                SS = re.subf(r"(\{|\})", "{1}{1}", SS)
+                SS = re.sub(r"(\{|\})", r"\1\1", SS)
 
             NSS = make_naked(SS, next_prefix)
 
@@ -541,7 +557,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer):
                 return TErr("StringMerger does NOT merge multiline strings.")
 
             num_of_strings += 1
-            prefix = get_string_prefix(leaf.value)
+            prefix = get_string_prefix(leaf.value).lower()
             if "r" in prefix:
                 return TErr("StringMerger does NOT merge raw strings.")
 
@@ -576,7 +592,7 @@ class StringParenStripper(StringTransformer):
             - The target string is NOT the only argument to a function call.
             - The target string is NOT a "pointless" string.
             - If the target string contains a PERCENT, the brackets are not
-              preceeded or followed by an operator with higher precedence than
+              preceded or followed by an operator with higher precedence than
               PERCENT.
 
     Transformations:
@@ -738,6 +754,18 @@ class BaseStringSplitter(StringTransformer):
         * The target string is not a multiline (i.e. triple-quote) string.
     """
 
+    STRING_OPERATORS: Final = [
+        token.EQEQUAL,
+        token.GREATER,
+        token.GREATEREQUAL,
+        token.LESS,
+        token.LESSEQUAL,
+        token.NOTEQUAL,
+        token.PERCENT,
+        token.PLUS,
+        token.STAR,
+    ]
+
     @abstractmethod
     def do_splitter_match(self, line: Line) -> TMatchResult:
         """
@@ -847,15 +875,15 @@ class BaseStringSplitter(StringTransformer):
                 p_idx -= 1
 
             P = LL[p_idx]
-            if P.type == token.PLUS:
-                # WMA4 a space and a '+' character (e.g. `+ STRING`).
-                offset += 2
+            if P.type in self.STRING_OPERATORS:
+                # WMA4 a space and a string operator (e.g. `+ STRING` or `== STRING`).
+                offset += len(str(P)) + 1
 
             if P.type == token.COMMA:
                 # WMA4 a space, a comma, and a closing bracket [e.g. `), STRING`].
                 offset += 3
 
-            if P.type in [token.COLON, token.EQUAL, token.NAME]:
+            if P.type in [token.COLON, token.EQUAL, token.PLUSEQUAL, token.NAME]:
                 # This conditional branch is meant to handle dictionary keys,
                 # variable assignments, 'return STRING' statement lines, and
                 # 'else STRING' ternary expression lines.
@@ -914,15 +942,66 @@ class BaseStringSplitter(StringTransformer):
         return max_string_length
 
 
-class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
+def iter_fexpr_spans(s: str) -> Iterator[Tuple[int, int]]:
+    """
+    Yields the span of every top-level replacement field (`{...}`) in @s.
+
+    Spans are half-open `(start, end)` index pairs: `s[start]` is the opening
+    brace and `s[end - 1]` is the matching closing brace. Nested braces (e.g.
+    the format spec in `{x:{width}}`) belong to the enclosing span.
+
+    @s is expected to be a valid f-string, but malformed input is tolerated:
+    unbalanced braces and unterminated strings do not make this crash.
+    """
+    open_braces: List[int] = []  # indices of the '{' that are still unmatched
+    pos = 0
+    length = len(s)
+    while pos < length:
+        char = s[pos]
+        if char == "{":
+            if not open_braces and s[pos + 1 : pos + 2] == "{":
+                # An escaped brace ('{{') in the literal part of the f-string.
+                pos += 2
+            else:
+                # The start of a (possibly nested) expression.
+                open_braces.append(pos)
+                pos += 1
+        elif char == "}":
+            # A '}' outside of any expression is escaped or stray: skip it.
+            if open_braces:
+                start = open_braces.pop()
+                if not open_braces:
+                    # Back at the top level, so the whole span is complete.
+                    yield (start, pos + 1)
+            pos += 1
+        elif open_braces and char in ("'", '"'):
+            # Inside an expression: fast forward through string literals so that
+            # any braces they contain are ignored. Backslashes are not legal in
+            # the expression portion of f-strings, so there are no escapes.
+            triple = s[pos : pos + 3]
+            quote = triple if triple in ("'''", '"""') else char
+            pos += len(quote)
+            while pos < length and s[pos : pos + len(quote)] != quote:
+                pos += 1
+            pos += len(quote)
+        else:
+            # Any other character needs no special handling.
+            pos += 1
+
+
+def fstring_contains_expr(s: str) -> bool:
+    return next(iter_fexpr_spans(s), None) is not None  # any f-expression in @s?
+
+
+class StringSplitter(BaseStringSplitter, CustomSplitMapMixin):
     """
     StringTransformer that splits "atom" strings (i.e. strings which exist on
     lines by themselves).
 
     Requirements:
-        * The line consists ONLY of a single string (with the exception of a
-        '+' symbol which MAY exist at the start of the line), MAYBE a string
-        trailer, and MAYBE a trailing comma.
+        * The line consists ONLY of a single string (possibly prefixed by a
+        string operator [e.g. '+' or '==']), MAYBE a string trailer, and MAYBE
+        a trailing comma.
             AND
         * All of the requirements listed in BaseStringSplitter's docstring.
 
@@ -952,18 +1031,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         CustomSplit objects and add them to the custom split map.
     """
 
-    MIN_SUBSTR_SIZE = 6
-    # Matches an "f-expression" (e.g. {var}) that might be found in an f-string.
-    RE_FEXPR = r"""
-    (?<!\{) (?:\{\{)* \{ (?!\{)
-        (?:
-            [^\{\}]
-            | \{\{
-            | \}\}
-            | (?R)
-        )+?
-    (?<!\}) \} (?:\}\})* (?!\})
-    """
+    MIN_SUBSTR_SIZE: Final = 6
 
     def do_splitter_match(self, line: Line) -> TMatchResult:
         LL = line.leaves
@@ -972,8 +1040,20 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
 
         idx = 0
 
-        # The first leaf MAY be a '+' symbol...
-        if is_valid_index(idx) and LL[idx].type == token.PLUS:
+        # The first two leaves MAY be the 'not in' keywords...
+        if (
+            is_valid_index(idx)
+            and is_valid_index(idx + 1)
+            and [LL[idx].type, LL[idx + 1].type] == [token.NAME, token.NAME]
+            and str(LL[idx]) + str(LL[idx + 1]) == "not in"
+        ):
+            idx += 2
+        # Else the first leaf MAY be a string operator symbol or the 'in' keyword...
+        elif is_valid_index(idx) and (
+            LL[idx].type in self.STRING_OPERATORS
+            or LL[idx].type == token.NAME
+            and str(LL[idx]) == "in"
+        ):
             idx += 1
 
         # The next/first leaf MAY be an empty LPAR...
@@ -1012,34 +1092,37 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         is_valid_index = is_valid_index_factory(LL)
         insert_str_child = insert_str_child_factory(LL[string_idx])
 
-        prefix = get_string_prefix(LL[string_idx].value)
+        prefix = get_string_prefix(LL[string_idx].value).lower()
 
         # We MAY choose to drop the 'f' prefix from substrings that don't
         # contain any f-expressions, but ONLY if the original f-string
         # contains at least one f-expression. Otherwise, we will alter the AST
         # of the program.
-        drop_pointless_f_prefix = ("f" in prefix) and re.search(
-            self.RE_FEXPR, LL[string_idx].value, re.VERBOSE
+        drop_pointless_f_prefix = ("f" in prefix) and fstring_contains_expr(
+            LL[string_idx].value
         )
 
         first_string_line = True
-        starts_with_plus = LL[0].type == token.PLUS
 
-        def line_needs_plus() -> bool:
-            return first_string_line and starts_with_plus
+        string_op_leaves = self._get_string_operator_leaves(LL)
+        string_op_leaves_length = (
+            sum([len(str(prefix_leaf)) for prefix_leaf in string_op_leaves]) + 1
+            if string_op_leaves
+            else 0
+        )
 
-        def maybe_append_plus(new_line: Line) -> None:
+        def maybe_append_string_operators(new_line: Line) -> None:
             """
             Side Effects:
-                If @line starts with a plus and this is the first line we are
-                constructing, this function appends a PLUS leaf to @new_line
-                and replaces the old PLUS leaf in the node structure. Otherwise
-                this function does nothing.
+                If @line starts with a string operator and this is the first
+                line we are constructing, this function appends the string
+                operator to @new_line and replaces the old string operator leaf
+                in the node structure. Otherwise this function does nothing.
             """
-            if line_needs_plus():
-                plus_leaf = Leaf(token.PLUS, "+")
-                replace_child(LL[0], plus_leaf)
-                new_line.append(plus_leaf)
+            maybe_prefix_leaves = string_op_leaves if first_string_line else []
+            for i, prefix_leaf in enumerate(maybe_prefix_leaves):
+                replace_child(LL[i], prefix_leaf)
+                new_line.append(prefix_leaf)
 
         ends_with_comma = (
             is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA
@@ -1054,7 +1137,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
             result = self.line_length
             result -= line.depth * 4
             result -= 1 if ends_with_comma else 0
-            result -= 2 if line_needs_plus() else 0
+            result -= string_op_leaves_length
             return result
 
         # --- Calculate Max Break Index (for string value)
@@ -1103,7 +1186,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
                 break_idx = csplit.break_idx
             else:
                 # Algorithmic Split (automatic)
-                max_bidx = max_break_idx - 2 if line_needs_plus() else max_break_idx
+                max_bidx = max_break_idx - string_op_leaves_length
                 maybe_break_idx = self._get_break_idx(rest_value, max_bidx)
                 if maybe_break_idx is None:
                     # If we are unable to algorithmically determine a good split
@@ -1124,21 +1207,32 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
 
             # --- Construct `next_value`
             next_value = rest_value[:break_idx] + QUOTE
+
+            # HACK: The following 'if' statement is a hack to fix the custom
+            # breakpoint index in the case of either: (a) substrings that were
+            # f-strings but will have the 'f' prefix removed OR (b) substrings
+            # that were not f-strings but will now become f-strings because of
+            # redundant use of the 'f' prefix (i.e. none of the substrings
+            # contain f-expressions but one or more of them had the 'f' prefix
+            # anyway; in which case, we will prepend 'f' to _all_ substrings).
+            #
+            # There is probably a better way to accomplish what is being done
+            # here...
+            #
+            # If this substring is an f-string, we _could_ remove the 'f'
+            # prefix, and the current custom split did NOT originally use a
+            # prefix...
             if (
-                # Are we allowed to try to drop a pointless 'f' prefix?
-                drop_pointless_f_prefix
-                # If we are, will we be successful?
-                and next_value != self._normalize_f_string(next_value, prefix)
+                next_value != self._normalize_f_string(next_value, prefix)
+                and use_custom_breakpoints
+                and not csplit.has_prefix
             ):
-                # If the current custom split did NOT originally use a prefix,
-                # then `csplit.break_idx` will be off by one after removing
+                # Then `csplit.break_idx` will be off by one after removing
                 # the 'f' prefix.
-                break_idx = (
-                    break_idx + 1
-                    if use_custom_breakpoints and not csplit.has_prefix
-                    else break_idx
-                )
+                break_idx += 1
                 next_value = rest_value[:break_idx] + QUOTE
+
+            if drop_pointless_f_prefix:
                 next_value = self._normalize_f_string(next_value, prefix)
 
             # --- Construct `next_leaf`
@@ -1148,7 +1242,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
 
             # --- Construct `next_line`
             next_line = line.clone()
-            maybe_append_plus(next_line)
+            maybe_append_string_operators(next_line)
             next_line.append(next_leaf)
             string_line_results.append(Ok(next_line))
 
@@ -1169,7 +1263,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         self._maybe_normalize_string_quotes(rest_leaf)
 
         last_line = line.clone()
-        maybe_append_plus(last_line)
+        maybe_append_string_operators(last_line)
 
         # If there are any leaves to the right of the target string...
         if is_valid_index(string_idx + 1):
@@ -1205,6 +1299,59 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
             last_line.comments = line.comments.copy()
             yield Ok(last_line)
 
+    def _iter_nameescape_slices(self, string: str) -> Iterator[Tuple[Index, Index]]:
+        """
+        Yields:
+            The inclusive (begin, end) indices of each \\N{...} escape in
+            @string, from its backslash to its closing brace. Splitting @string
+            inside such a range would break the escape (which is NOT allowed).
+        """
+        # True - the previous backslash was unescaped
+        # False - the previous backslash was escaped *or* there was no backslash
+        previous_was_unescaped_backslash = False
+        it = iter(enumerate(string))
+        for idx, c in it:
+            if c == "\\":
+                previous_was_unescaped_backslash = not previous_was_unescaped_backslash
+                continue
+            if not previous_was_unescaped_backslash or c != "N":
+                previous_was_unescaped_backslash = False
+                continue
+            previous_was_unescaped_backslash = False
+
+            begin = idx - 1  # the position of backslash before \N{...}
+            for idx, c in it:  # advances the SAME iterator as the outer loop
+                if c == "}":
+                    end = idx
+                    break
+            else:
+                # no closing brace was found: malformed nameescape expression?
+                # should have been detected by AST parsing earlier...
+                raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!")
+            yield begin, end
+
+    def _iter_fexpr_slices(self, string: str) -> Iterator[Tuple[Index, Index]]:
+        """
+        Yields:
+            The half-open span of each f-expression in @string; splitting
+            @string inside one is NOT allowed. Nothing is yielded unless
+            @string has an 'f' (or 'F') prefix.
+        """
+        is_fstring = "f" in get_string_prefix(string).lower()
+        if is_fstring:
+            yield from iter_fexpr_spans(string)
+
+    def _get_illegal_split_indices(self, string: str) -> Set[Index]:
+        """Returns every index of @string at which it must NOT be split."""
+        # NOTE(review): f-expression spans are half-open, so `end + 1` also marks
+        # the index right after the closing brace -- confirm this is intended.
+        illegal_indices: Set[Index] = set()
+        for begin, end in self._iter_fexpr_slices(string):
+            illegal_indices.update(range(begin, end + 1))
+        for begin, end in self._iter_nameescape_slices(string):
+            illegal_indices.update(range(begin, end + 1))
+        return illegal_indices
+
     def _get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]:
         """
         This method contains the algorithm that StringSplitter uses to
@@ -1234,40 +1381,15 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         assert is_valid_index(max_break_idx)
         assert_is_leaf_string(string)
 
-        _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None
+        _illegal_split_indices = self._get_illegal_split_indices(string)
 
-        def fexpr_slices() -> Iterator[Tuple[Index, Index]]:
-            """
-            Yields:
-                All ranges of @string which, if @string were to be split there,
-                would result in the splitting of an f-expression (which is NOT
-                allowed).
-            """
-            nonlocal _fexpr_slices
-
-            if _fexpr_slices is None:
-                _fexpr_slices = []
-                for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE):
-                    _fexpr_slices.append(match.span())
-
-            yield from _fexpr_slices
-
-        is_fstring = "f" in get_string_prefix(string)
-
-        def breaks_fstring_expression(i: Index) -> bool:
+        def breaks_unsplittable_expression(i: Index) -> bool:
             """
             Returns:
                 True iff returning @i would result in the splitting of an
-                f-expression (which is NOT allowed).
+                unsplittable expression (which is NOT allowed).
             """
-            if not is_fstring:
-                return False
-
-            for (start, end) in fexpr_slices():
-                if start <= i < end:
-                    return True
-
-            return False
+            return i in _illegal_split_indices
 
         def passes_all_checks(i: Index) -> bool:
             """
@@ -1291,7 +1413,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
                 is_space
                 and is_not_escaped
                 and is_big_enough
-                and not breaks_fstring_expression(i)
+                and not breaks_unsplittable_expression(i)
             )
 
         # First, we check all indices BELOW @max_break_idx.
@@ -1333,7 +1455,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         """
         assert_is_leaf_string(string)
 
-        if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE):
+        if "f" in prefix and not fstring_contains_expr(string):
             new_prefix = prefix.replace("f", "")
 
             temp = string[len(prefix) :]
@@ -1345,8 +1467,19 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter):
         else:
             return string
 
+    def _get_string_operator_leaves(self, leaves: Iterable[Leaf]) -> List[Leaf]:
+        """Return fresh copies of the leading operator/NAME leaves of @leaves."""
 
-class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
+        string_op_leaves: List[Leaf] = []
+        operator_types = self.STRING_OPERATORS + [token.NAME]
+        for leaf in leaves:
+            if leaf.type not in operator_types:
+                break  # only the leading run of operator leaves is wanted
+            string_op_leaves.append(Leaf(leaf.type, str(leaf).strip()))
+        return string_op_leaves
+
+
+class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin):
     """
     StringTransformer that splits non-"atom" strings (i.e. strings that do not
     exist on lines by themselves).
@@ -1398,6 +1531,11 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
     def do_splitter_match(self, line: Line) -> TMatchResult:
         LL = line.leaves
 
+        if line.leaves[-1].type in OPENING_BRACKETS:
+            return TErr(
+                "Cannot wrap parens around a line that ends in an opening bracket."
+            )
+
         string_idx = (
             self._return_match(LL)
             or self._else_match(LL)
@@ -1665,9 +1803,10 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter):
                 right_leaves.pop()
 
             if old_parens_exist:
-                assert (
-                    right_leaves and right_leaves[-1].type == token.RPAR
-                ), "Apparently, old parentheses do NOT exist?!"
+                assert right_leaves and right_leaves[-1].type == token.RPAR, (
+                    "Apparently, old parentheses do NOT exist?!"
+                    f" (left_leaves={left_leaves}, right_leaves={right_leaves})"
+                )
                 old_rpar_leaf = right_leaves.pop()
 
             append_leaves(string_line, line, right_leaves)
@@ -1725,20 +1864,20 @@ class StringParser:
         ```
     """
 
-    DEFAULT_TOKEN = -1
+    DEFAULT_TOKEN: Final = 20210605
 
     # String Parser States
-    START = 1
-    DOT = 2
-    NAME = 3
-    PERCENT = 4
-    SINGLE_FMT_ARG = 5
-    LPAR = 6
-    RPAR = 7
-    DONE = 8
+    START: Final = 1
+    DOT: Final = 2
+    NAME: Final = 3
+    PERCENT: Final = 4
+    SINGLE_FMT_ARG: Final = 5
+    LPAR: Final = 6
+    RPAR: Final = 7
+    DONE: Final = 8
 
     # Lookup Table for Next State
-    _goto: Dict[Tuple[ParserState, NodeType], ParserState] = {
+    _goto: Final[Dict[Tuple[ParserState, NodeType], ParserState]] = {
         # A string trailer may start with '.' OR '%'.
         (START, token.DOT): DOT,
         (START, token.PERCENT): PERCENT,