X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/00e7e12a3a412ea386806d5d4eeaed345e912940..7aa37ea0adf864baf3ef3dfbcfaf5ff1ff780250:/src/black/trans.py?ds=sidebyside diff --git a/src/black/trans.py b/src/black/trans.py index ca620f6..c0cc926 100644 --- a/src/black/trans.py +++ b/src/black/trans.py @@ -1,39 +1,57 @@ """ String transformers that can split and merge strings. """ +import re from abc import ABC, abstractmethod from collections import defaultdict from dataclasses import dataclass -import regex as re from typing import ( Any, Callable, + ClassVar, Collection, Dict, + Final, Iterable, Iterator, List, + Literal, Optional, Sequence, + Set, Tuple, TypeVar, Union, ) -from black.rusty import Result, Ok, Err +from mypy_extensions import trait -from black.mode import Feature -from black.nodes import syms, replace_child, parent_type -from black.nodes import is_empty_par, is_empty_lpar, is_empty_rpar -from black.nodes import OPENING_BRACKETS, CLOSING_BRACKETS, STANDALONE_COMMENT -from black.lines import Line, append_leaves -from black.brackets import BracketMatchError from black.comments import contains_pragma_comment -from black.strings import has_triple_quotes, get_string_prefix, assert_is_leaf_string -from black.strings import normalize_string_quotes - -from blib2to3.pytree import Leaf, Node +from black.lines import Line, append_leaves +from black.mode import Feature, Mode +from black.nodes import ( + CLOSING_BRACKETS, + OPENING_BRACKETS, + STANDALONE_COMMENT, + is_empty_lpar, + is_empty_par, + is_empty_rpar, + is_part_of_annotation, + parent_type, + replace_child, + syms, +) +from black.rusty import Err, Ok, Result +from black.strings import ( + assert_is_leaf_string, + count_chars_in_width, + get_string_prefix, + has_triple_quotes, + normalize_string_quotes, + str_width, +) from blib2to3.pgen2 import token +from blib2to3.pytree import Leaf, Node class CannotTransform(Exception): @@ -43,13 +61,15 @@ class CannotTransform(Exception): # types T = TypeVar("T") LN = Union[Leaf, Node] -Transformer = Callable[[Line, Collection[Feature]], Iterator[Line]] +Transformer = Callable[[Line, Collection[Feature], Mode], Iterator[Line]] Index = int NodeType = int ParserState = int StringID = int TResult = Result[T, CannotTransform] # (T)ransform Result -TMatchResult = TResult[Index] +TMatchResult = TResult[List[Index]] + +SPLIT_SAFE_CHARS = frozenset(["\u3001", "\u3002", "\uff0c"]) # East Asian stops def TErr(err_msg: str) -> Err[CannotTransform]: @@ -61,7 +81,86 @@ def TErr(err_msg: str) -> Err[CannotTransform]: return Err(cant_transform) -@dataclass # type: ignore +def hug_power_op( + line: Line, features: Collection[Feature], mode: Mode +) -> Iterator[Line]: + """A transformer which normalizes spacing around power operators.""" + + # Performance optimization to avoid unnecessary Leaf clones and other ops. + for leaf in line.leaves: + if leaf.type == token.DOUBLESTAR: + break + else: + raise CannotTransform("No doublestar token was found in the line.") + + def is_simple_lookup(index: int, step: Literal[1, -1]) -> bool: + # Brackets and parentheses indicate calls, subscripts, etc. ... + # basically stuff that doesn't count as "simple". Only a NAME lookup + # or dotted lookup (eg. NAME.NAME) is OK. + if step == -1: + disallowed = {token.RPAR, token.RSQB} + else: + disallowed = {token.LPAR, token.LSQB} + + while 0 <= index < len(line.leaves): + current = line.leaves[index] + if current.type in disallowed: + return False + if current.type not in {token.NAME, token.DOT} or current.value == "for": + # If the current token isn't disallowed, we'll assume this is simple as + # only the disallowed tokens are semantically attached to this lookup + # expression we're checking. Also, stop early if we hit the 'for' bit + # of a comprehension. + return True + + index += step + + return True + + def is_simple_operand(index: int, kind: Literal["base", "exponent"]) -> bool: + # An operand is considered "simple" if's a NAME, a numeric CONSTANT, a simple + # lookup (see above), with or without a preceding unary operator. + start = line.leaves[index] + if start.type in {token.NAME, token.NUMBER}: + return is_simple_lookup(index, step=(1 if kind == "exponent" else -1)) + + if start.type in {token.PLUS, token.MINUS, token.TILDE}: + if line.leaves[index + 1].type in {token.NAME, token.NUMBER}: + # step is always one as bases with a preceding unary op will be checked + # for simplicity starting from the next token (so it'll hit the check + # above). + return is_simple_lookup(index + 1, step=1) + + return False + + new_line = line.clone() + should_hug = False + for idx, leaf in enumerate(line.leaves): + new_leaf = leaf.clone() + if should_hug: + new_leaf.prefix = "" + should_hug = False + + should_hug = ( + (0 < idx < len(line.leaves) - 1) + and leaf.type == token.DOUBLESTAR + and is_simple_operand(idx - 1, kind="base") + and line.leaves[idx - 1].value != "lambda" + and is_simple_operand(idx + 1, kind="exponent") + ) + if should_hug: + new_leaf.prefix = "" + + # We have to be careful to make a new line properly: + # - bracket related metadata must be maintained (handled by Line.append) + # - comments need to copied over, updating the leaf IDs they're attached to + new_line.append(new_leaf, preformatted=True) + for comment_leaf in line.comments_after(leaf): + new_line.append(comment_leaf, preformatted=True) + + yield new_line + + class StringTransformer(ABC): """ An implementation of the Transformer protocol that relies on its @@ -89,31 +188,40 @@ class StringTransformer(ABC): as much as possible. """ - line_length: int - normalize_strings: bool - __name__ = "StringTransformer" + __name__: Final = "StringTransformer" + + # Ideally this would be a dataclass, but unfortunately mypyc breaks when used with + # `abc.ABC`. + def __init__(self, line_length: int, normalize_strings: bool) -> None: + self.line_length = line_length + self.normalize_strings = normalize_strings @abstractmethod def do_match(self, line: Line) -> TMatchResult: """ Returns: - * Ok(string_idx) such that `line.leaves[string_idx]` is our target - string, if a match was able to be made. - OR - * Err(CannotTransform), if a match was not able to be made. + * Ok(string_indices) such that for each index, `line.leaves[index]` + is our target string if a match was able to be made. For + transformers that don't result in more lines (e.g. StringMerger, + StringParenStripper), multiple matches and transforms are done at + once to reduce the complexity. + OR + * Err(CannotTransform), if no match could be made. """ @abstractmethod - def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + def do_transform( + self, line: Line, string_indices: List[int] + ) -> Iterator[TResult[Line]]: """ Yields: * Ok(new_line) where new_line is the new transformed line. - OR + OR * Err(CannotTransform) if the transformation failed for some reason. The - `do_match(...)` template method should usually be used to reject - the form of the given Line, but in some cases it is difficult to - know whether or not a Line meets the StringTransformer's - requirements until the transformation is already midway. + `do_match(...)` template method should usually be used to reject + the form of the given Line, but in some cases it is difficult to + know whether or not a Line meets the StringTransformer's + requirements until the transformation is already midway. Side Effects: This method should NOT mutate @line directly, but it MAY mutate the @@ -122,7 +230,9 @@ class StringTransformer(ABC): yield an CannotTransform after that point.) """ - def __call__(self, line: Line, _features: Collection[Feature]) -> Iterator[Line]: + def __call__( + self, line: Line, _features: Collection[Feature], _mode: Mode + ) -> Iterator[Line]: """ StringTransformer instances have a call signature that mirrors that of the Transformer type. @@ -145,9 +255,9 @@ class StringTransformer(ABC): " this line as one that it can transform." ) from cant_transform - string_idx = match_result.ok() + string_indices = match_result.ok() - for line_result in self.do_transform(line, string_idx): + for line_result in self.do_transform(line, string_indices): if isinstance(line_result, Err): cant_transform = line_result.err() raise CannotTransform( @@ -183,6 +293,7 @@ class CustomSplit: break_idx: int +@trait class CustomSplitMapMixin: """ This mixin class is used to map merged strings to a sequence of @@ -190,8 +301,10 @@ class CustomSplitMapMixin: the resultant substrings go over the configured max line length. """ - _Key = Tuple[StringID, str] - _CUSTOM_SPLIT_MAP: Dict[_Key, Tuple[CustomSplit, ...]] = defaultdict(tuple) + _Key: ClassVar = Tuple[StringID, str] + _CUSTOM_SPLIT_MAP: ClassVar[Dict[_Key, Tuple[CustomSplit, ...]]] = defaultdict( + tuple + ) @staticmethod def _get_key(string: str) -> "CustomSplitMapMixin._Key": @@ -218,8 +331,8 @@ class CustomSplitMapMixin: Returns: * A list of the custom splits that are mapped to @string, if any - exist. - OR + exist. + OR * [], otherwise. Side Effects: @@ -242,20 +355,20 @@ class CustomSplitMapMixin: return key in self._CUSTOM_SPLIT_MAP -class StringMerger(CustomSplitMapMixin, StringTransformer): +class StringMerger(StringTransformer, CustomSplitMapMixin): """StringTransformer that merges strings together. Requirements: (A) The line contains adjacent strings such that ALL of the validation checks - listed in StringMerger.__validate_msg(...)'s docstring pass. - OR + listed in StringMerger._validate_msg(...)'s docstring pass. + OR (B) The line contains a string which uses line continuation backslashes. Transformations: Depending on which of the two requirements above where met, either: (A) The string group associated with the target string is merged. - OR + OR (B) All line-continuation backslashes are removed from the target string. Collaborations: @@ -267,28 +380,50 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): is_valid_index = is_valid_index_factory(LL) - for (i, leaf) in enumerate(LL): + string_indices = [] + idx = 0 + while is_valid_index(idx): + leaf = LL[idx] if ( leaf.type == token.STRING - and is_valid_index(i + 1) - and LL[i + 1].type == token.STRING + and is_valid_index(idx + 1) + and LL[idx + 1].type == token.STRING ): - return Ok(i) + if not is_part_of_annotation(leaf): + string_indices.append(idx) - if leaf.type == token.STRING and "\\\n" in leaf.value: - return Ok(i) + # Advance to the next non-STRING leaf. + idx += 2 + while is_valid_index(idx) and LL[idx].type == token.STRING: + idx += 1 - return TErr("This line has no strings that need merging.") + elif leaf.type == token.STRING and "\\\n" in leaf.value: + string_indices.append(idx) + # Advance to the next non-STRING leaf. + idx += 1 + while is_valid_index(idx) and LL[idx].type == token.STRING: + idx += 1 - def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + else: + idx += 1 + + if string_indices: + return Ok(string_indices) + else: + return TErr("This line has no strings that need merging.") + + def do_transform( + self, line: Line, string_indices: List[int] + ) -> Iterator[TResult[Line]]: new_line = line + rblc_result = self._remove_backslash_line_continuation_chars( - new_line, string_idx + new_line, string_indices ) if isinstance(rblc_result, Ok): new_line = rblc_result.ok() - msg_result = self._merge_string_group(new_line, string_idx) + msg_result = self._merge_string_group(new_line, string_indices) if isinstance(msg_result, Ok): new_line = msg_result.ok() @@ -309,7 +444,7 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): @staticmethod def _remove_backslash_line_continuation_chars( - line: Line, string_idx: int + line: Line, string_indices: List[int] ) -> TResult[Line]: """ Merge strings that were split across multiple lines using @@ -323,34 +458,44 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): """ LL = line.leaves - string_leaf = LL[string_idx] - if not ( - string_leaf.type == token.STRING - and "\\\n" in string_leaf.value - and not has_triple_quotes(string_leaf.value) - ): + indices_to_transform = [] + for string_idx in string_indices: + string_leaf = LL[string_idx] + if ( + string_leaf.type == token.STRING + and "\\\n" in string_leaf.value + and not has_triple_quotes(string_leaf.value) + ): + indices_to_transform.append(string_idx) + + if not indices_to_transform: return TErr( - f"String leaf {string_leaf} does not contain any backslash line" - " continuation characters." + "Found no string leaves that contain backslash line continuation" + " characters." ) new_line = line.clone() new_line.comments = line.comments.copy() append_leaves(new_line, line, LL) - new_string_leaf = new_line.leaves[string_idx] - new_string_leaf.value = new_string_leaf.value.replace("\\\n", "") + for string_idx in indices_to_transform: + new_string_leaf = new_line.leaves[string_idx] + new_string_leaf.value = new_string_leaf.value.replace("\\\n", "") return Ok(new_line) - def _merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]: + def _merge_string_group( + self, line: Line, string_indices: List[int] + ) -> TResult[Line]: """ - Merges string group (i.e. set of adjacent strings) where the first - string in the group is `line.leaves[string_idx]`. + Merges string groups (i.e. set of adjacent strings). + + Each index from `string_indices` designates one string group's first + leaf in `line.leaves`. Returns: Ok(new_line), if ALL of the validation checks found in - __validate_msg(...) pass. + _validate_msg(...) pass. OR Err(CannotTransform), otherwise. """ @@ -358,10 +503,54 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): is_valid_index = is_valid_index_factory(LL) - vresult = self._validate_msg(line, string_idx) - if isinstance(vresult, Err): - return vresult + # A dict of {string_idx: tuple[num_of_strings, string_leaf]}. + merged_string_idx_dict: Dict[int, Tuple[int, Leaf]] = {} + for string_idx in string_indices: + vresult = self._validate_msg(line, string_idx) + if isinstance(vresult, Err): + continue + merged_string_idx_dict[string_idx] = self._merge_one_string_group( + LL, string_idx, is_valid_index + ) + + if not merged_string_idx_dict: + return TErr("No string group is merged") + + # Build the final line ('new_line') that this method will later return. + new_line = line.clone() + previous_merged_string_idx = -1 + previous_merged_num_of_strings = -1 + for i, leaf in enumerate(LL): + if i in merged_string_idx_dict: + previous_merged_string_idx = i + previous_merged_num_of_strings, string_leaf = merged_string_idx_dict[i] + new_line.append(string_leaf) + if ( + previous_merged_string_idx + <= i + < previous_merged_string_idx + previous_merged_num_of_strings + ): + for comment_leaf in line.comments_after(LL[i]): + new_line.append(comment_leaf, preformatted=True) + continue + + append_leaves(new_line, line, [leaf]) + + return Ok(new_line) + + def _merge_one_string_group( + self, LL: List[Leaf], string_idx: int, is_valid_index: Callable[[int], bool] + ) -> Tuple[int, Leaf]: + """ + Merges one string group where the first string in the group is + `LL[string_idx]`. + + Returns: + A tuple of `(num_of_strings, leaf)` where `num_of_strings` is the + number of strings merged and `leaf` is the newly merged string + to be replaced in the new line. + """ # If the string group is wrapped inside an Atom node, we must make sure # to later replace that Atom with our new (merged) string leaf. atom_node = LL[string_idx].parent @@ -387,6 +576,12 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): characters have been escaped. """ assert_is_leaf_string(string) + if "f" in string_prefix: + string = _toggle_fexpr_quotes(string, QUOTE) + # After quotes toggling, quotes in expressions won't be escaped + # because quotes can't be reused in f-strings. So we can simply + # let the escaping logic below run without knowing f-string + # expressions. RE_EVEN_BACKSLASHES = r"(?:(? TResult[None]: """Validate (M)erge (S)tring (G)roup - Transform-time string validation logic for __merge_string_group(...). + Transform-time string validation logic for _merge_string_group(...). Returns: * Ok(None), if ALL validation checks (listed below) pass. @@ -504,6 +700,11 @@ class StringMerger(CustomSplitMapMixin, StringTransformer): - The set of all string prefixes in the string group is of length greater than one and is not equal to {"", "f"}. - The string group consists of raw strings. + - The string group is stringified type annotations. We don't want to + process stringified type annotations since pyright doesn't support + them spanning multiple string values. (NOTE: mypy, pytype, pyre do + support them, so we can change if pyright also gains support in the + future. See https://github.com/microsoft/pyright/issues/4359.) """ # We first check for "inner" stand-alone comments (i.e. stand-alone # comments that have a string leaf before them AND after them). @@ -593,7 +794,15 @@ class StringParenStripper(StringTransformer): is_valid_index = is_valid_index_factory(LL) - for (idx, leaf) in enumerate(LL): + string_indices = [] + + idx = -1 + while True: + idx += 1 + if idx >= len(LL): + break + leaf = LL[idx] + # Should be a string... if leaf.type != token.STRING: continue @@ -675,45 +884,76 @@ class StringParenStripper(StringTransformer): }: continue - return Ok(string_idx) + string_indices.append(string_idx) + idx = string_idx + while idx < len(LL) - 1 and LL[idx + 1].type == token.STRING: + idx += 1 + if string_indices: + return Ok(string_indices) return TErr("This line has no strings wrapped in parens.") - def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + def do_transform( + self, line: Line, string_indices: List[int] + ) -> Iterator[TResult[Line]]: LL = line.leaves - string_parser = StringParser() - rpar_idx = string_parser.parse(LL, string_idx) + string_and_rpar_indices: List[int] = [] + for string_idx in string_indices: + string_parser = StringParser() + rpar_idx = string_parser.parse(LL, string_idx) + + should_transform = True + for leaf in (LL[string_idx - 1], LL[rpar_idx]): + if line.comments_after(leaf): + # Should not strip parentheses which have comments attached + # to them. + should_transform = False + break + if should_transform: + string_and_rpar_indices.extend((string_idx, rpar_idx)) - for leaf in (LL[string_idx - 1], LL[rpar_idx]): - if line.comments_after(leaf): - yield TErr( - "Will not strip parentheses which have comments attached to them." - ) - return + if string_and_rpar_indices: + yield Ok(self._transform_to_new_line(line, string_and_rpar_indices)) + else: + yield Err( + CannotTransform("All string groups have comments attached to them.") + ) + + def _transform_to_new_line( + self, line: Line, string_and_rpar_indices: List[int] + ) -> Line: + LL = line.leaves new_line = line.clone() new_line.comments = line.comments.copy() - try: - append_leaves(new_line, line, LL[: string_idx - 1]) - except BracketMatchError: - # HACK: I believe there is currently a bug somewhere in - # right_hand_split() that is causing brackets to not be tracked - # properly by a shared BracketTracker. - append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True) - - string_leaf = Leaf(token.STRING, LL[string_idx].value) - LL[string_idx - 1].remove() - replace_child(LL[string_idx], string_leaf) - new_line.append(string_leaf) - - append_leaves( - new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :] - ) - LL[rpar_idx].remove() + previous_idx = -1 + # We need to sort the indices, since string_idx and its matching + # rpar_idx may not come in order, e.g. in + # `("outer" % ("inner".join(items)))`, the "inner" string's + # string_idx is smaller than "outer" string's rpar_idx. + for idx in sorted(string_and_rpar_indices): + leaf = LL[idx] + lpar_or_rpar_idx = idx - 1 if leaf.type == token.STRING else idx + append_leaves(new_line, line, LL[previous_idx + 1 : lpar_or_rpar_idx]) + if leaf.type == token.STRING: + string_leaf = Leaf(token.STRING, LL[idx].value) + LL[lpar_or_rpar_idx].remove() # Remove lpar. + replace_child(LL[idx], string_leaf) + new_line.append(string_leaf) + # replace comments + old_comments = new_line.comments.pop(id(LL[idx]), []) + new_line.comments.setdefault(id(string_leaf), []).extend(old_comments) + else: + LL[lpar_or_rpar_idx].remove() # This is a rpar. - yield Ok(new_line) + previous_idx = idx + + # Append the leaves after the last idx: + append_leaves(new_line, line, LL[idx + 1 :]) + + return new_line class BaseStringSplitter(StringTransformer): @@ -724,21 +964,24 @@ class BaseStringSplitter(StringTransformer): Requirements: * The target string value is responsible for the line going over the - line length limit. It follows that after all of black's other line - split methods have been exhausted, this line (or one of the resulting - lines after all line splits are performed) would still be over the - line_length limit unless we split this string. - AND + line length limit. It follows that after all of black's other line + split methods have been exhausted, this line (or one of the resulting + lines after all line splits are performed) would still be over the + line_length limit unless we split this string. + AND + * The target string is NOT a "pointless" string (i.e. a string that has - no parent or siblings). - AND + no parent or siblings). + AND + * The target string is not followed by an inline comment that appears - to be a pragma. - AND + to be a pragma. + AND + * The target string is not a multiline (i.e. triple-quote) string. """ - STRING_OPERATORS = [ + STRING_OPERATORS: Final = [ token.EQEQUAL, token.GREATER, token.GREATEREQUAL, @@ -766,7 +1009,12 @@ class BaseStringSplitter(StringTransformer): if isinstance(match_result, Err): return match_result - string_idx = match_result.ok() + string_indices = match_result.ok() + assert len(string_indices) == 1, ( + f"{self.__class__.__name__} should only find one match at a time, found" + f" {len(string_indices)}" + ) + string_idx = string_indices[0] vresult = self._validate(line, string_idx) if isinstance(vresult, Err): return vresult @@ -781,7 +1029,7 @@ class BaseStringSplitter(StringTransformer): Returns: * Ok(None), if ALL of the requirements are met. - OR + OR * Err(CannotTransform), if ANY of the requirements are NOT met. """ LL = line.leaves @@ -922,20 +1170,140 @@ class BaseStringSplitter(StringTransformer): # WMA4 the length of the inline comment. offset += len(comment_leaf.value) - max_string_length = self.line_length - offset + max_string_length = count_chars_in_width(str(line), self.line_length - offset) return max_string_length + @staticmethod + def _prefer_paren_wrap_match(LL: List[Leaf]) -> Optional[int]: + """ + Returns: + string_idx such that @LL[string_idx] is equal to our target (i.e. + matched) string, if this line matches the "prefer paren wrap" statement + requirements listed in the 'Requirements' section of the StringParenWrapper + class's docstring. + OR + None, otherwise. + """ + # The line must start with a string. + if LL[0].type != token.STRING: + return None + + matching_nodes = [ + syms.listmaker, + syms.dictsetmaker, + syms.testlist_gexp, + ] + # If the string is an immediate child of a list/set/tuple literal... + if ( + parent_type(LL[0]) in matching_nodes + or parent_type(LL[0].parent) in matching_nodes + ): + # And the string is surrounded by commas (or is the first/last child)... + prev_sibling = LL[0].prev_sibling + next_sibling = LL[0].next_sibling + if ( + not prev_sibling + and not next_sibling + and parent_type(LL[0]) == syms.atom + ): + # If it's an atom string, we need to check the parent atom's siblings. + parent = LL[0].parent + assert parent is not None # For type checkers. + prev_sibling = parent.prev_sibling + next_sibling = parent.next_sibling + if (not prev_sibling or prev_sibling.type == token.COMMA) and ( + not next_sibling or next_sibling.type == token.COMMA + ): + return 0 + + return None + + +def iter_fexpr_spans(s: str) -> Iterator[Tuple[int, int]]: + """ + Yields spans corresponding to expressions in a given f-string. + Spans are half-open ranges (left inclusive, right exclusive). + Assumes the input string is a valid f-string, but will not crash if the input + string is invalid. + """ + stack: List[int] = [] # our curly paren stack + i = 0 + while i < len(s): + if s[i] == "{": + # if we're in a string part of the f-string, ignore escaped curly braces + if not stack and i + 1 < len(s) and s[i + 1] == "{": + i += 2 + continue + stack.append(i) + i += 1 + continue + + if s[i] == "}": + if not stack: + i += 1 + continue + j = stack.pop() + # we've made it back out of the expression! yield the span + if not stack: + yield (j, i + 1) + i += 1 + continue + + # if we're in an expression part of the f-string, fast forward through strings + # note that backslashes are not legal in the expression portion of f-strings + if stack: + delim = None + if s[i : i + 3] in ("'''", '"""'): + delim = s[i : i + 3] + elif s[i] in ("'", '"'): + delim = s[i] + if delim: + i += len(delim) + while i < len(s) and s[i : i + len(delim)] != delim: + i += 1 + i += len(delim) + continue + i += 1 + + +def fstring_contains_expr(s: str) -> bool: + return any(iter_fexpr_spans(s)) + + +def _toggle_fexpr_quotes(fstring: str, old_quote: str) -> str: + """ + Toggles quotes used in f-string expressions that are `old_quote`. + + f-string expressions can't contain backslashes, so we need to toggle the + quotes if the f-string itself will end up using the same quote. We can + simply toggle without escaping because, quotes can't be reused in f-string + expressions. They will fail to parse. -class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): + NOTE: If PEP 701 is accepted, above statement will no longer be true. + Though if quotes can be reused, we can simply reuse them without updates or + escaping, once Black figures out how to parse the new grammar. + """ + new_quote = "'" if old_quote == '"' else '"' + parts = [] + previous_index = 0 + for start, end in iter_fexpr_spans(fstring): + parts.append(fstring[previous_index:start]) + parts.append(fstring[start:end].replace(old_quote, new_quote)) + previous_index = end + parts.append(fstring[previous_index:]) + return "".join(parts) + + +class StringSplitter(BaseStringSplitter, CustomSplitMapMixin): """ StringTransformer that splits "atom" strings (i.e. strings which exist on lines by themselves). Requirements: * The line consists ONLY of a single string (possibly prefixed by a - string operator [e.g. '+' or '==']), MAYBE a string trailer, and MAYBE - a trailing comma. - AND + string operator [e.g. '+' or '==']), MAYBE a string trailer, and MAYBE + a trailing comma. + AND * All of the requirements listed in BaseStringSplitter's docstring. Transformations: @@ -964,22 +1332,14 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): CustomSplit objects and add them to the custom split map. """ - MIN_SUBSTR_SIZE = 6 - # Matches an "f-expression" (e.g. {var}) that might be found in an f-string. - RE_FEXPR = r""" - (? TMatchResult: LL = line.leaves + if self._prefer_paren_wrap_match(LL) is not None: + return TErr("Line needs to be wrapped in parens first.") + is_valid_index = is_valid_index_factory(LL) idx = 0 @@ -1026,10 +1386,17 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): if is_valid_index(idx): return TErr("This line does not end with a string.") - return Ok(string_idx) + return Ok([string_idx]) - def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + def do_transform( + self, line: Line, string_indices: List[int] + ) -> Iterator[TResult[Line]]: LL = line.leaves + assert len(string_indices) == 1, ( + f"{self.__class__.__name__} should only find one match at a time, found" + f" {len(string_indices)}" + ) + string_idx = string_indices[0] QUOTE = LL[string_idx].value[-1] @@ -1042,15 +1409,15 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): # contain any f-expressions, but ONLY if the original f-string # contains at least one f-expression. Otherwise, we will alter the AST # of the program. - drop_pointless_f_prefix = ("f" in prefix) and re.search( - self.RE_FEXPR, LL[string_idx].value, re.VERBOSE + drop_pointless_f_prefix = ("f" in prefix) and fstring_contains_expr( + LL[string_idx].value ) first_string_line = True string_op_leaves = self._get_string_operator_leaves(LL) string_op_leaves_length = ( - sum([len(str(prefix_leaf)) for prefix_leaf in string_op_leaves]) + 1 + sum(len(str(prefix_leaf)) for prefix_leaf in string_op_leaves) + 1 if string_op_leaves else 0 ) @@ -1072,11 +1439,13 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): is_valid_index(string_idx + 1) and LL[string_idx + 1].type == token.COMMA ) - def max_last_string() -> int: + def max_last_string_column() -> int: """ Returns: - The max allowed length of the string value used for the last - line we will construct. + The max allowed width of the string value used for the last + line we will construct. Note that this value means the width + rather than the number of characters (e.g., many East Asian + characters expand to two columns). """ result = self.line_length result -= line.depth * 4 @@ -1084,14 +1453,14 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): result -= string_op_leaves_length return result - # --- Calculate Max Break Index (for string value) + # --- Calculate Max Break Width (for string value) # We start with the line length limit - max_break_idx = self.line_length + max_break_width = self.line_length # The last index of a string of length N is N-1. - max_break_idx -= 1 + max_break_width -= 1 # Leading whitespace is not present in the string value (e.g. Leaf.value). - max_break_idx -= line.depth * 4 - if max_break_idx < 0: + max_break_width -= line.depth * 4 + if max_break_width < 0: yield TErr( f"Unable to split {LL[string_idx].value} at such high of a line depth:" f" {line.depth}" @@ -1104,7 +1473,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): # line limit. use_custom_breakpoints = bool( custom_splits - and all(csplit.break_idx <= max_break_idx for csplit in custom_splits) + and all(csplit.break_idx <= max_break_width for csplit in custom_splits) ) # Temporary storage for the remaining chunk of the string line that @@ -1120,7 +1489,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): if use_custom_breakpoints: return len(custom_splits) > 1 else: - return len(rest_value) > max_last_string() + return str_width(rest_value) > max_last_string_column() string_line_results: List[Ok[Line]] = [] while more_splits_should_be_made(): @@ -1130,7 +1499,10 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): break_idx = csplit.break_idx else: # Algorithmic Split (automatic) - max_bidx = max_break_idx - string_op_leaves_length + max_bidx = ( + count_chars_in_width(rest_value, max_break_width) + - string_op_leaves_length + ) maybe_break_idx = self._get_break_idx(rest_value, max_bidx) if maybe_break_idx is None: # If we are unable to algorithmically determine a good split @@ -1167,9 +1539,14 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): # prefix, and the current custom split did NOT originally use a # prefix... if ( - next_value != self._normalize_f_string(next_value, prefix) - and use_custom_breakpoints + use_custom_breakpoints and not csplit.has_prefix + and ( + # `next_value == prefix + QUOTE` happens when the custom + # split is an empty string. + next_value == prefix + QUOTE + or next_value != self._normalize_f_string(next_value, prefix) + ) ): # Then `csplit.break_idx` will be off by one after removing # the 'f' prefix. @@ -1222,7 +1599,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): # Try to fit them all on the same line with the last substring... if ( - len(temp_value) <= max_last_string() + str_width(temp_value) <= max_last_string_column() or LL[string_idx + 1].type == token.COMMA ): last_line.append(rest_leaf) @@ -1243,6 +1620,59 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): last_line.comments = line.comments.copy() yield Ok(last_line) + def _iter_nameescape_slices(self, string: str) -> Iterator[Tuple[Index, Index]]: + """ + Yields: + All ranges of @string which, if @string were to be split there, + would result in the splitting of an \\N{...} expression (which is NOT + allowed). + """ + # True - the previous backslash was unescaped + # False - the previous backslash was escaped *or* there was no backslash + previous_was_unescaped_backslash = False + it = iter(enumerate(string)) + for idx, c in it: + if c == "\\": + previous_was_unescaped_backslash = not previous_was_unescaped_backslash + continue + if not previous_was_unescaped_backslash or c != "N": + previous_was_unescaped_backslash = False + continue + previous_was_unescaped_backslash = False + + begin = idx - 1 # the position of backslash before \N{...} + for idx, c in it: + if c == "}": + end = idx + break + else: + # malformed nameescape expression? + # should have been detected by AST parsing earlier... + raise RuntimeError(f"{self.__class__.__name__} LOGIC ERROR!") + yield begin, end + + def _iter_fexpr_slices(self, string: str) -> Iterator[Tuple[Index, Index]]: + """ + Yields: + All ranges of @string which, if @string were to be split there, + would result in the splitting of an f-expression (which is NOT + allowed). + """ + if "f" not in get_string_prefix(string).lower(): + return + yield from iter_fexpr_spans(string) + + def _get_illegal_split_indices(self, string: str) -> Set[Index]: + illegal_indices: Set[Index] = set() + iterators = [ + self._iter_fexpr_slices(string), + self._iter_nameescape_slices(string), + ] + for it in iterators: + for begin, end in it: + illegal_indices.update(range(begin, end + 1)) + return illegal_indices + def _get_break_idx(self, string: str, max_break_idx: int) -> Optional[int]: """ This method contains the algorithm that StringSplitter uses to @@ -1272,40 +1702,15 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): assert is_valid_index(max_break_idx) assert_is_leaf_string(string) - _fexpr_slices: Optional[List[Tuple[Index, Index]]] = None + _illegal_split_indices = self._get_illegal_split_indices(string) - def fexpr_slices() -> Iterator[Tuple[Index, Index]]: - """ - Yields: - All ranges of @string which, if @string were to be split there, - would result in the splitting of an f-expression (which is NOT - allowed). - """ - nonlocal _fexpr_slices - - if _fexpr_slices is None: - _fexpr_slices = [] - for match in re.finditer(self.RE_FEXPR, string, re.VERBOSE): - _fexpr_slices.append(match.span()) - - yield from _fexpr_slices - - is_fstring = "f" in get_string_prefix(string).lower() - - def breaks_fstring_expression(i: Index) -> bool: + def breaks_unsplittable_expression(i: Index) -> bool: """ Returns: True iff returning @i would result in the splitting of an - f-expression (which is NOT allowed). + unsplittable expression (which is NOT allowed). """ - if not is_fstring: - return False - - for (start, end) in fexpr_slices(): - if start <= i < end: - return True - - return False + return i in _illegal_split_indices def passes_all_checks(i: Index) -> bool: """ @@ -1314,6 +1719,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): section of this classes' docstring would be be met by returning @i. """ is_space = string[i] == " " + is_split_safe = is_valid_index(i - 1) and string[i - 1] in SPLIT_SAFE_CHARS is_not_escaped = True j = i - 1 @@ -1326,10 +1732,10 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): and len(string[:i]) >= self.MIN_SUBSTR_SIZE ) return ( - is_space + (is_space or is_split_safe) and is_not_escaped and is_big_enough - and not breaks_fstring_expression(i) + and not breaks_unsplittable_expression(i) ) # First, we check all indices BELOW @max_break_idx. @@ -1371,7 +1777,7 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): """ assert_is_leaf_string(string) - if "f" in prefix and not re.search(self.RE_FEXPR, string, re.VERBOSE): + if "f" in prefix and not fstring_contains_expr(string): new_prefix = prefix.replace("f", "") temp = string[len(prefix) :] @@ -1395,29 +1801,35 @@ class StringSplitter(CustomSplitMapMixin, BaseStringSplitter): return string_op_leaves -class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): +class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin): """ - StringTransformer that splits non-"atom" strings (i.e. strings that do not - exist on lines by themselves). + StringTransformer that wraps strings in parens and then splits at the LPAR. Requirements: All of the requirements listed in BaseStringSplitter's docstring in addition to the requirements listed below: * The line is a return/yield statement, which returns/yields a string. - OR + OR * The line is part of a ternary expression (e.g. `x = y if cond else - z`) such that the line starts with `else `, where is - some string. - OR + z`) such that the line starts with `else `, where is + some string. + OR * The line is an assert statement, which ends with a string. - OR + OR * The line is an assignment statement (e.g. `x = ` or `x += - `) such that the variable is being assigned the value of some - string. - OR + `) such that the variable is being assigned the value of some + string. + OR * The line is a dictionary key assignment where some valid key is being - assigned the value of some string. + assigned the value of some string. + OR + * The line is an lambda expression and the value is a string. + OR + * The line starts with an "atom" string that prefers to be wrapped in + parens. It's preferred to be wrapped when it's is an immediate child of + a list/set/tuple literal, AND the string is surrounded by commas (or is + the first/last child). Transformations: The chosen string is wrapped in parentheses and then split at the LPAR. @@ -1442,6 +1854,9 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): changed such that it no longer needs to be given its own line, StringParenWrapper relies on StringParenStripper to clean up the parentheses it created. + + For "atom" strings that prefers to be wrapped in parens, it requires + StringSplitter to hold the split until the string is wrapped in parens. """ def do_splitter_match(self, line: Line) -> TMatchResult: @@ -1457,16 +1872,19 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): or self._else_match(LL) or self._assert_match(LL) or self._assign_match(LL) - or self._dict_match(LL) + or self._dict_or_lambda_match(LL) + or self._prefer_paren_wrap_match(LL) ) if string_idx is not None: string_value = line.leaves[string_idx].value - # If the string has no spaces... - if " " not in string_value: + # If the string has neither spaces nor East Asian stops... + if not any( + char == " " or char in SPLIT_SAFE_CHARS for char in string_value + ): # And will still violate the line length limit when split... - max_string_length = self.line_length - ((line.depth + 1) * 4) - if len(string_value) > max_string_length: + max_string_width = self.line_length - ((line.depth + 1) * 4) + if str_width(string_value) > max_string_width: # And has no associated custom splits... if not self.has_custom_splits(string_value): # Then we should NOT put this string on its own line. @@ -1475,7 +1893,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): " resultant line would still be over the specified line" " length and can't be split further by StringSplitter." ) - return Ok(string_idx) + return Ok([string_idx]) return TErr("This line does not contain any non-atomic strings.") @@ -1547,7 +1965,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert": is_valid_index = is_valid_index_factory(LL) - for (i, leaf) in enumerate(LL): + for i, leaf in enumerate(LL): # We MUST find a comma... if leaf.type == token.COMMA: idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1 @@ -1585,7 +2003,7 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): ): is_valid_index = is_valid_index_factory(LL) - for (i, leaf) in enumerate(LL): + for i, leaf in enumerate(LL): # We MUST find either an '=' or '+=' symbol... if leaf.type in [token.EQUAL, token.PLUSEQUAL]: idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1 @@ -1614,23 +2032,24 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): return None @staticmethod - def _dict_match(LL: List[Leaf]) -> Optional[int]: + def _dict_or_lambda_match(LL: List[Leaf]) -> Optional[int]: """ Returns: string_idx such that @LL[string_idx] is equal to our target (i.e. matched) string, if this line matches the dictionary key assignment - statement requirements listed in the 'Requirements' section of this - classes' docstring. + statement or lambda expression requirements listed in the + 'Requirements' section of this classes' docstring. OR None, otherwise. """ - # If this line is apart of a dictionary key assignment... - if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]: + # If this line is a part of a dictionary key assignment or lambda expression... + parent_types = [parent_type(LL[0]), parent_type(LL[0].parent)] + if syms.dictsetmaker in parent_types or syms.lambdef in parent_types: is_valid_index = is_valid_index_factory(LL) - for (i, leaf) in enumerate(LL): - # We MUST find a colon... - if leaf.type == token.COLON: + for i, leaf in enumerate(LL): + # We MUST find a colon, it can either be dict's or lambda's colon... + if leaf.type == token.COLON and i < len(LL) - 1: idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1 # That colon MUST be followed by a string... @@ -1651,8 +2070,15 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): return None - def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + def do_transform( + self, line: Line, string_indices: List[int] + ) -> Iterator[TResult[Line]]: LL = line.leaves + assert len(string_indices) == 1, ( + f"{self.__class__.__name__} should only find one match at a time, found" + f" {len(string_indices)}" + ) + string_idx = string_indices[0] is_valid_index = is_valid_index_factory(LL) insert_str_child = insert_str_child_factory(LL[string_idx]) @@ -1724,6 +2150,25 @@ class StringParenWrapper(CustomSplitMapMixin, BaseStringSplitter): f" (left_leaves={left_leaves}, right_leaves={right_leaves})" ) old_rpar_leaf = right_leaves.pop() + elif right_leaves and right_leaves[-1].type == token.RPAR: + # Special case for lambda expressions as dict's value, e.g.: + # my_dict = { + # "key": lambda x: f"formatted: {x}, + # } + # After wrapping the dict's value with parentheses, the string is + # followed by a RPAR but its opening bracket is lambda's, not + # the string's: + # "key": (lambda x: f"formatted: {x}), + opening_bracket = right_leaves[-1].opening_bracket + if opening_bracket is not None and opening_bracket in left_leaves: + index = left_leaves.index(opening_bracket) + if ( + index > 0 + and index < len(left_leaves) - 1 + and left_leaves[index - 1].type == token.COLON + and left_leaves[index + 1].value == "lambda" + ): + right_leaves.pop() append_leaves(string_line, line, right_leaves) @@ -1780,20 +2225,20 @@ class StringParser: ``` """ - DEFAULT_TOKEN = -1 + DEFAULT_TOKEN: Final = 20210605 # String Parser States - START = 1 - DOT = 2 - NAME = 3 - PERCENT = 4 - SINGLE_FMT_ARG = 5 - LPAR = 6 - RPAR = 7 - DONE = 8 + START: Final = 1 + DOT: Final = 2 + NAME: Final = 3 + PERCENT: Final = 4 + SINGLE_FMT_ARG: Final = 5 + LPAR: Final = 6 + RPAR: Final = 7 + DONE: Final = 8 # Lookup Table for Next State - _goto: Dict[Tuple[ParserState, NodeType], ParserState] = { + _goto: Final[Dict[Tuple[ParserState, NodeType], ParserState]] = { # A string trailer may start with '.' OR '%'. (START, token.DOT): DOT, (START, token.PERCENT): PERCENT, @@ -1830,7 +2275,7 @@ class StringParser: Returns: The index directly after the last leaf which is apart of the string trailer, if a "trailer" exists. - OR + OR @string_idx + 1, if no string "trailer" exists. """ assert leaves[string_idx].type == token.STRING @@ -1844,11 +2289,11 @@ class StringParser: """ Pre-conditions: * On the first call to this function, @leaf MUST be the leaf that - was directly after the string leaf in question (e.g. if our target - string is `line.leaves[i]` then the first call to this method must - be `line.leaves[i + 1]`). + was directly after the string leaf in question (e.g. if our target + string is `line.leaves[i]` then the first call to this method must + be `line.leaves[i + 1]`). * On the next call to this function, the leaf parameter passed in - MUST be the leaf directly following @leaf. + MUST be the leaf directly following @leaf. Returns: True iff @leaf is apart of the string's trailer.