X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/00e7e12a3a412ea386806d5d4eeaed345e912940..6debce63bc2429b1680f8838592f2e56e3df6b27:/src/black/trans.py diff --git a/src/black/trans.py b/src/black/trans.py index ca620f6..01aa80e 100644 --- a/src/black/trans.py +++ b/src/black/trans.py @@ -4,10 +4,11 @@ String transformers that can split and merge strings. from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass -import regex as re +import re from typing import ( Any, Callable, + ClassVar, Collection, Dict, Iterable, @@ -15,10 +16,19 @@ from typing import ( List, Optional, Sequence, + Set, Tuple, TypeVar, Union, ) +import sys + +if sys.version_info < (3, 8): + from typing_extensions import Literal, Final +else: + from typing import Literal, Final + +from mypy_extensions import trait from black.rusty import Result, Ok, Err @@ -61,7 +71,88 @@ def TErr(err_msg: str) -> Err[CannotTransform]: return Err(cant_transform) -@dataclass # type: ignore +def hug_power_op(line: Line, features: Collection[Feature]) -> Iterator[Line]: + """A transformer which normalizes spacing around power operators.""" + + # Performance optimization to avoid unnecessary Leaf clones and other ops. + for leaf in line.leaves: + if leaf.type == token.DOUBLESTAR: + break + else: + raise CannotTransform("No doublestar token was found in the line.") + + def is_simple_lookup(index: int, step: Literal[1, -1]) -> bool: + # Brackets and parentheses indicate calls, subscripts, etc. ... + # basically stuff that doesn't count as "simple". Only a NAME lookup + # or dotted lookup (eg. NAME.NAME) is OK. + if step == -1: + disallowed = {token.RPAR, token.RSQB} + else: + disallowed = {token.LPAR, token.LSQB} + + while 0 <= index < len(line.leaves): + current = line.leaves[index] + if current.type in disallowed: + return False + if current.type not in {token.NAME, token.DOT} or current.value == "for": + # If the current token isn't disallowed, we'll assume this is simple as + # only the disallowed tokens are semantically attached to this lookup + # expression we're checking. Also, stop early if we hit the 'for' bit + # of a comprehension. + return True + + index += step + + return True + + def is_simple_operand(index: int, kind: Literal["base", "exponent"]) -> bool: + # An operand is considered "simple" if's a NAME, a numeric CONSTANT, a simple + # lookup (see above), with or without a preceding unary operator. + start = line.leaves[index] + if start.type in {token.NAME, token.NUMBER}: + return is_simple_lookup(index, step=(1 if kind == "exponent" else -1)) + + if start.type in {token.PLUS, token.MINUS, token.TILDE}: + if line.leaves[index + 1].type in {token.NAME, token.NUMBER}: + # step is always one as bases with a preceding unary op will be checked + # for simplicity starting from the next token (so it'll hit the check + # above). + return is_simple_lookup(index + 1, step=1) + + return False + + leaves: List[Leaf] = [] + should_hug = False + for idx, leaf in enumerate(line.leaves): + new_leaf = leaf.clone() + if should_hug: + new_leaf.prefix = "" + should_hug = False + + should_hug = ( + (0 < idx < len(line.leaves) - 1) + and leaf.type == token.DOUBLESTAR + and is_simple_operand(idx - 1, kind="base") + and line.leaves[idx - 1].value != "lambda" + and is_simple_operand(idx + 1, kind="exponent") + ) + if should_hug: + new_leaf.prefix = "" + + leaves.append(new_leaf) + + yield Line( + mode=line.mode, + depth=line.depth, + leaves=leaves, + comments=line.comments, + bracket_tracker=line.bracket_tracker, + inside_brackets=line.inside_brackets, + should_split_rhs=line.should_split_rhs, + magic_trailing_comma=line.magic_trailing_comma, + ) + + class StringTransformer(ABC): """ An implementation of the Transformer protocol that relies on its @@ -89,9 +180,13 @@ class StringTransformer(ABC): as much as possible. """ - line_length: int - normalize_strings: bool - __name__ = "StringTransformer" + __name__: Final = "StringTransformer" + + # Ideally this would be a dataclass, but unfortunately mypyc breaks when used with + # `abc.ABC`. + def __init__(self, line_length: int, normalize_strings: bool) -> None: + self.line_length = line_length + self.normalize_strings = normalize_strings @abstractmethod def do_match(self, line: Line) -> TMatchResult: @@ -183,6 +278,7 @@ class CustomSplit: break_idx: int +@trait class CustomSplitMapMixin: """ This mixin class is used to map merged strings to a sequence of @@ -190,8 +286,10 @@ class CustomSplitMapMixin: the resultant substrings go over the configured max line length. """ - _Key = Tuple[StringID, str] - _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple) + _Key: ClassVar = Tuple[StringID, str] + _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict( + tuple + ) @staticmethod def _get_key(string: str) -> "CustomSplitMapMixin._Key": @@ -242,7 +340,7 @@ class CustomSplitMapMixin: return key in self._CUSTOM_SPLIT_MAP -class StringMerger(CustomSplitMapMixin, StringTransformer): +class StringMerger(StringTransformer, CustomSplitMapMixin): """StringTransformer that merges strings together. Requirements: @@ -267,7 +365,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): is_valid_index = is_valid_index_factory(LL) - for (i, leaf) in enumerate(LL): + for i, leaf in enumerate(LL): if ( leaf.type == token.STRING and is_valid_index(i + 1) @@ -437,7 +535,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): # with 'f'... if "f" in prefix and "f" not in next_prefix: # Then we must escape any braces contained in this substring. - SS = re.subf(r"(\{|\})", "{1}{1}", SS) + SS = re.sub(r"(\{|\})", r"\1\1", SS) NSS = make_naked(SS, next_prefix) @@ -472,7 +570,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): # Build the final line ('new_line') that this method will later return. new_line = line.clone() - for (i, leaf) in enumerate(LL): + for i, leaf in enumerate(LL): if i == string_idx: new_line.append(string_leaf) @@ -593,7 +691,7 @@ class StringParenStripper(StringTransformer): is_valid_index = is_valid_index_factory(LL) - for (idx, leaf) in enumerate(LL): + for idx, leaf in enumerate(LL): # Should be a string... if leaf.type != token.STRING: continue @@ -738,7 +836,7 @@ class BaseStringSplitter(StringTransformer): * The target string is not a multiline (i.e. triple-quote) string. """ - STRING_OPERATORS = [ + STRING_OPERATORS: Final = [ token.EQEQUAL, token.GREATER, token.GREATEREQUAL, @@ -926,7 +1024,58 @@ class BaseStringSplitter(StringTransformer): return max_string_length -class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): +def iter_fexpr_spans(s: str) -> Iterator[Tuple[int, int]]: + """ + Yields spans corresponding to expressions in a given f-string. + Spans are half-open ranges (left inclusive, right exclusive). + Assumes the input string is a valid f-string, but will not crash if the input + string is invalid. + """ + stack: List[int] = [] # our curly paren stack + i = 0 + while i < len(s): + if s[i] == "{": + # if we're in a string part of the f-string, ignore escaped curly braces + if not stack and i + 1 < len(s) and s[i + 1] == "{": + i += 2 + continue + stack.append(i) + i += 1 + continue + + if s[i] == "}": + if not stack: + i += 1 + continue + j = stack.pop() + # we've made it back out of the expression! yield the span + if not stack: + yield (j, i + 1) + i += 1 + continue + + # if we're in an expression part of the f-string, fast forward through strings + # note that backslashes are not legal in the expression portion of f-strings + if stack: + delim = None + if s[i : i + 3] in ("'''", '"""'): + delim = s[i : i + 3] + elif s[i] in ("'", '"'): + delim = s[i] + if delim: + i += len(delim) + while i < len(s) and s[i : i + len(delim)] != delim: + i += 1 + i += len(delim) + continue + i += 1 + + +def fstring_contains_expr(s: str) -> bool: + return any(iter_fexpr_spans(s)) + + +class StringSplitter(BaseStringSplitter, CustomSplitMapMixin): """ StringTransformer that splits "atom" strings (i.e. strings which exist on lines by themselves). @@ -964,18 +1113,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): CustomSplit objects and add them to the custom split map. """ - MIN_SUBSTR_SIZE = 6 - # Matches an "f-expression" (e.g. {var}) that might be found in an f-string. - RE_FEXPR = r""" - (? TMatchResult: LL = line.leaves @@ -1042,8 +1180,8 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): # contain any f-expressions, but ONLY if the original f-string # contains at least one f-expression. Otherwise, we will alter the AST # of the program. - drop_pointless_f_prefix = ("f" in prefix) and re.search( - self.RE_FEXPR, LL[string_idx].value, re.VERBOSE + drop_pointless_f_prefix = ("f" in prefix) and fstring_contains_expr( + LL[string_idx].value ) first_string_line = True @@ -1243,6 +1381,59 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): last_line.comments = line.comments.copy() yield Ok(last_line) + def _iter_nameescape_slices(self, string: str) -> Iterator[Tuple[Index, Index]]: + """ + Yields: + All ranges of @string which, if @string were to be split there, + would result in the splitting of an \\N{...} expression (which is NOT + allowed). + """ + # True - the previous backslash was unescaped + # False - the previous backslash was escaped *or* there was no backslash + previous_was_unescaped_backslash = False + it = iter(enumerate(string)) + for idx, c in it: + if c == "\\": + previous_was_unescaped_backslash = not previous_was_unescaped_backslash + continue + if not previous_was_unescaped_backslash or c != "N": + previous_was_unescaped_backslash = False + continue + previous_was_unescaped_backslash = False + + begin = idx - 1 # the position of backslash before \N{...} + for idx, c in it: + if c == "}": + end = idx + break + else: + # malformed nameescape expression? + # should have been detected by AST parsing earlier... + raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!") + yield begin, end + + def _iter_fexpr_slices(self, string: str) -> Iterator[Tuple[Index, Index]]: + """ + Yields: + All ranges of @string which, if @string were to be split there, + would result in the splitting of an f-expression (which is NOT + allowed). + """ + if "f" not in get_string_prefix(string).lower(): + return + yield from iter_fexpr_spans(string) + + def _get_illegal_split_indices(self, string: str) -> Set[Index]: + illegal_indices: Set[Index] = set() + iterators = [ + self._iter_fexpr_slices(string), + self._iter_nameescape_slices(string), + ] + for it in iterators: + for begin, end in it: + illegal_indices.update(range(begin, end + 1)) + return illegal_indices + def _get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]: """ This method contains the algorithm that StringSplitter uses to @@ -1272,40 +1463,15 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): assert is_valid_index(max_break_idx) assert_is_leaf_string(string) - _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None + _illegal_split_indices = self._get_illegal_split_indices(string) - def fexpr_slices() -> Iterator[Tuple[Index, Index]]: - """ - Yields: - All ranges of @string which, if @string were to be split there, - would result in the splitting of an f-expression (which is NOT - allowed). - """ - nonlocal _fexpr_slices - - if _fexpr_slices is None: - _fexpr_slices = [] - for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE): - _fexpr_slices.append(match.span()) - - yield from _fexpr_slices - - is_fstring = "f" in get_string_prefix(string).lower() - - def breaks_fstring_expression(i: Index) -> bool: + def breaks_unsplittable_expression(i: Index) -> bool: """ Returns: True iff returning @i would result in the splitting of an - f-expression (which is NOT allowed). + unsplittable expression (which is NOT allowed). """ - if not is_fstring: - return False - - for (start, end) in fexpr_slices(): - if start <= i < end: - return True - - return False + return i in _illegal_split_indices def passes_all_checks(i: Index) -> bool: """ @@ -1329,7 +1495,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): is_space and is_not_escaped and is_big_enough - and not breaks_fstring_expression(i) + and not breaks_unsplittable_expression(i) ) # First, we check all indices BELOW @max_break_idx. @@ -1371,7 +1537,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): """ assert_is_leaf_string(string) - if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE): + if "f" in prefix and not fstring_contains_expr(string): new_prefix = prefix.replace("f", "") temp = string[len(prefix) :] @@ -1395,7 +1561,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): return string_op_leaves -class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): +class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin): """ StringTransformer that splits non-"atom" strings (i.e. strings that do not exist on lines by themselves). @@ -1547,7 +1713,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert": is_valid_index = is_valid_index_factory(LL) - for (i, leaf) in enumerate(LL): + for i, leaf in enumerate(LL): # We MUST find a comma... if leaf.type == token.COMMA: idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1 @@ -1585,7 +1751,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): ): is_valid_index = is_valid_index_factory(LL) - for (i, leaf) in enumerate(LL): + for i, leaf in enumerate(LL): # We MUST find either an '=' or '+=' symbol... if leaf.type in [token.EQUAL, token.PLUSEQUAL]: idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1 @@ -1628,7 +1794,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]: is_valid_index = is_valid_index_factory(LL) - for (i, leaf) in enumerate(LL): + for i, leaf in enumerate(LL): # We MUST find a colon... if leaf.type == token.COLON: idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1 @@ -1780,20 +1946,20 @@ class StringParser: ``` """ - DEFAULT_TOKEN = -1 + DEFAULT_TOKEN: Final = 20210605 # String Parser States - START = 1 - DOT = 2 - NAME = 3 - PERCENT = 4 - SINGLE_FMT_ARG = 5 - LPAR = 6 - RPAR = 7 - DONE = 8 + START: Final = 1 + DOT: Final = 2 + NAME: Final = 3 + PERCENT: Final = 4 + SINGLE_FMT_ARG: Final = 5 + LPAR: Final = 6 + RPAR: Final = 7 + DONE: Final = 8 # Lookup Table for Next State - _goto: Dict[Tuple[ParserState, NodeType], ParserState] = { + _goto: Final[Dict[Tuple[ParserState, NodeType], ParserState]] = { # A string trailer may start with '.' OR '%'. (START, token.DOT): DOT, (START, token.PERCENT): PERCENT,