X-Git-Url: https://git.madduck.net/etc/vim.git/blobdiff_plain/b73b77a9b039d5afa6ab32202bff2b2232258d45..d4ff985853c8d140d73b9d362604deedb41eb20e:/src/black/trans.py diff --git a/src/black/trans.py b/src/black/trans.py index 8893ab0..ec07f5e 100644 --- a/src/black/trans.py +++ b/src/black/trans.py @@ -30,7 +30,6 @@ else: from mypy_extensions import trait -from black.brackets import BracketMatchError from black.comments import contains_pragma_comment from black.lines import Line, append_leaves from black.mode import Feature @@ -41,6 +40,7 @@ from black.nodes import ( is_empty_lpar, is_empty_par, is_empty_rpar, + is_part_of_annotation, parent_type, replace_child, syms, @@ -69,7 +69,7 @@ NodeType = int ParserState = int StringID = int TResult = Result[T, CannotTransform] # (T)ransform Result -TMatchResult = TResult[Index] +TMatchResult = TResult[List[Index]] def TErr(err_msg: str) -> Err[CannotTransform]: @@ -198,14 +198,19 @@ class StringTransformer(ABC): def do_match(self, line: Line) -> TMatchResult: """ Returns: - * Ok(string_idx) such that `line.leaves[string_idx]` is our target - string, if a match was able to be made. + * Ok(string_indices) such that for each index, `line.leaves[index]` + is our target string if a match was able to be made. For + transformers that don't result in more lines (e.g. StringMerger, + StringParenStripper), multiple matches and transforms are done at + once to reduce the complexity. OR - * Err(CannotTransform), if a match was not able to be made. + * Err(CannotTransform), if no match could be made. """ @abstractmethod - def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + def do_transform( + self, line: Line, string_indices: List[int] + ) -> Iterator[TResult[Line]]: """ Yields: * Ok(new_line) where new_line is the new transformed line. @@ -246,9 +251,9 @@ class StringTransformer(ABC): " this line as one that it can transform." ) from cant_transform - string_idx = match_result.ok() + string_indices = match_result.ok() - for line_result in self.do_transform(line, string_idx): + for line_result in self.do_transform(line, string_indices): if isinstance(line_result, Err): cant_transform = line_result.err() raise CannotTransform( @@ -351,7 +356,7 @@ class StringMerger(StringTransformer, CustomSplitMapMixin): Requirements: (A) The line contains adjacent strings such that ALL of the validation checks - listed in StringMerger.__validate_msg(...)'s docstring pass. + listed in StringMerger._validate_msg(...)'s docstring pass. OR (B) The line contains a string which uses line continuation backslashes. @@ -371,28 +376,50 @@ class StringMerger(StringTransformer, CustomSplitMapMixin): is_valid_index = is_valid_index_factory(LL) - for i, leaf in enumerate(LL): + string_indices = [] + idx = 0 + while is_valid_index(idx): + leaf = LL[idx] if ( leaf.type == token.STRING - and is_valid_index(i + 1) - and LL[i + 1].type == token.STRING + and is_valid_index(idx + 1) + and LL[idx + 1].type == token.STRING ): - return Ok(i) + if not is_part_of_annotation(leaf): + string_indices.append(idx) + + # Advance to the next non-STRING leaf. + idx += 2 + while is_valid_index(idx) and LL[idx].type == token.STRING: + idx += 1 - if leaf.type == token.STRING and "\\\n" in leaf.value: - return Ok(i) + elif leaf.type == token.STRING and "\\\n" in leaf.value: + string_indices.append(idx) + # Advance to the next non-STRING leaf. + idx += 1 + while is_valid_index(idx) and LL[idx].type == token.STRING: + idx += 1 + + else: + idx += 1 - return TErr("This line has no strings that need merging.") + if string_indices: + return Ok(string_indices) + else: + return TErr("This line has no strings that need merging.") - def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + def do_transform( + self, line: Line, string_indices: List[int] + ) -> Iterator[TResult[Line]]: new_line = line + rblc_result = self._remove_backslash_line_continuation_chars( - new_line, string_idx + new_line, string_indices ) if isinstance(rblc_result, Ok): new_line = rblc_result.ok() - msg_result = self._merge_string_group(new_line, string_idx) + msg_result = self._merge_string_group(new_line, string_indices) if isinstance(msg_result, Ok): new_line = msg_result.ok() @@ -413,7 +440,7 @@ class StringMerger(StringTransformer, CustomSplitMapMixin): @staticmethod def _remove_backslash_line_continuation_chars( - line: Line, string_idx: int + line: Line, string_indices: List[int] ) -> TResult[Line]: """ Merge strings that were split across multiple lines using @@ -427,34 +454,44 @@ class StringMerger(StringTransformer, CustomSplitMapMixin): """ LL = line.leaves - string_leaf = LL[string_idx] - if not ( - string_leaf.type == token.STRING - and "\\\n" in string_leaf.value - and not has_triple_quotes(string_leaf.value) - ): + indices_to_transform = [] + for string_idx in string_indices: + string_leaf = LL[string_idx] + if ( + string_leaf.type == token.STRING + and "\\\n" in string_leaf.value + and not has_triple_quotes(string_leaf.value) + ): + indices_to_transform.append(string_idx) + + if not indices_to_transform: return TErr( - f"String leaf {string_leaf} does not contain any backslash line" - " continuation characters." + "Found no string leaves that contain backslash line continuation" + " characters." ) new_line = line.clone() new_line.comments = line.comments.copy() append_leaves(new_line, line, LL) - new_string_leaf = new_line.leaves[string_idx] - new_string_leaf.value = new_string_leaf.value.replace("\\\n", "") + for string_idx in indices_to_transform: + new_string_leaf = new_line.leaves[string_idx] + new_string_leaf.value = new_string_leaf.value.replace("\\\n", "") return Ok(new_line) - def _merge_string_group(self, line: Line, string_idx: int) -> TResult[Line]: + def _merge_string_group( + self, line: Line, string_indices: List[int] + ) -> TResult[Line]: """ - Merges string group (i.e. set of adjacent strings) where the first - string in the group is `line.leaves[string_idx]`. + Merges string groups (i.e. set of adjacent strings). + + Each index from `string_indices` designates one string group's first + leaf in `line.leaves`. Returns: Ok(new_line), if ALL of the validation checks found in - __validate_msg(...) pass. + _validate_msg(...) pass. OR Err(CannotTransform), otherwise. """ @@ -462,10 +499,54 @@ class StringMerger(StringTransformer, CustomSplitMapMixin): is_valid_index = is_valid_index_factory(LL) - vresult = self._validate_msg(line, string_idx) - if isinstance(vresult, Err): - return vresult + # A dict of {string_idx: tuple[num_of_strings, string_leaf]}. + merged_string_idx_dict: Dict[int, Tuple[int, Leaf]] = {} + for string_idx in string_indices: + vresult = self._validate_msg(line, string_idx) + if isinstance(vresult, Err): + continue + merged_string_idx_dict[string_idx] = self._merge_one_string_group( + LL, string_idx, is_valid_index + ) + + if not merged_string_idx_dict: + return TErr("No string group is merged") + # Build the final line ('new_line') that this method will later return. + new_line = line.clone() + previous_merged_string_idx = -1 + previous_merged_num_of_strings = -1 + for i, leaf in enumerate(LL): + if i in merged_string_idx_dict: + previous_merged_string_idx = i + previous_merged_num_of_strings, string_leaf = merged_string_idx_dict[i] + new_line.append(string_leaf) + + if ( + previous_merged_string_idx + <= i + < previous_merged_string_idx + previous_merged_num_of_strings + ): + for comment_leaf in line.comments_after(LL[i]): + new_line.append(comment_leaf, preformatted=True) + continue + + append_leaves(new_line, line, [leaf]) + + return Ok(new_line) + + def _merge_one_string_group( + self, LL: List[Leaf], string_idx: int, is_valid_index: Callable[[int], bool] + ) -> Tuple[int, Leaf]: + """ + Merges one string group where the first string in the group is + `LL[string_idx]`. + + Returns: + A tuple of `(num_of_strings, leaf)` where `num_of_strings` is the + number of strings merged and `leaf` is the newly merged string + to be replaced in the new line. + """ # If the string group is wrapped inside an Atom node, we must make sure # to later replace that Atom with our new (merged) string leaf. atom_node = LL[string_idx].parent @@ -588,27 +669,14 @@ class StringMerger(StringTransformer, CustomSplitMapMixin): # Else replace the atom node with the new string leaf. replace_child(atom_node, string_leaf) - # Build the final line ('new_line') that this method will later return. - new_line = line.clone() - for i, leaf in enumerate(LL): - if i == string_idx: - new_line.append(string_leaf) - - if string_idx <= i < string_idx + num_of_strings: - for comment_leaf in line.comments_after(LL[i]): - new_line.append(comment_leaf, preformatted=True) - continue - - append_leaves(new_line, line, [leaf]) - self.add_custom_splits(string_leaf.value, custom_splits) - return Ok(new_line) + return num_of_strings, string_leaf @staticmethod def _validate_msg(line: Line, string_idx: int) -> TResult[None]: """Validate (M)erge (S)tring (G)roup - Transform-time string validation logic for __merge_string_group(...). + Transform-time string validation logic for _merge_string_group(...). Returns: * Ok(None), if ALL validation checks (listed below) pass. @@ -622,6 +690,11 @@ class StringMerger(StringTransformer, CustomSplitMapMixin): - The set of all string prefixes in the string group is of length greater than one and is not equal to {"", "f"}. - The string group consists of raw strings. + - The string group is stringified type annotations. We don't want to + process stringified type annotations since pyright doesn't support + them spanning multiple string values. (NOTE: mypy, pytype, pyre do + support them, so we can change if pyright also gains support in the + future. See https://github.com/microsoft/pyright/issues/4359.) """ # We first check for "inner" stand-alone comments (i.e. stand-alone # comments that have a string leaf before them AND after them). @@ -711,7 +784,15 @@ class StringParenStripper(StringTransformer): is_valid_index = is_valid_index_factory(LL) - for idx, leaf in enumerate(LL): + string_indices = [] + + idx = -1 + while True: + idx += 1 + if idx >= len(LL): + break + leaf = LL[idx] + # Should be a string... if leaf.type != token.STRING: continue @@ -793,45 +874,73 @@ class StringParenStripper(StringTransformer): }: continue - return Ok(string_idx) + string_indices.append(string_idx) + idx = string_idx + while idx < len(LL) - 1 and LL[idx + 1].type == token.STRING: + idx += 1 + if string_indices: + return Ok(string_indices) return TErr("This line has no strings wrapped in parens.") - def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + def do_transform( + self, line: Line, string_indices: List[int] + ) -> Iterator[TResult[Line]]: LL = line.leaves - string_parser = StringParser() - rpar_idx = string_parser.parse(LL, string_idx) + string_and_rpar_indices: List[int] = [] + for string_idx in string_indices: + string_parser = StringParser() + rpar_idx = string_parser.parse(LL, string_idx) + + should_transform = True + for leaf in (LL[string_idx - 1], LL[rpar_idx]): + if line.comments_after(leaf): + # Should not strip parentheses which have comments attached + # to them. + should_transform = False + break + if should_transform: + string_and_rpar_indices.extend((string_idx, rpar_idx)) - for leaf in (LL[string_idx - 1], LL[rpar_idx]): - if line.comments_after(leaf): - yield TErr( - "Will not strip parentheses which have comments attached to them." - ) - return + if string_and_rpar_indices: + yield Ok(self._transform_to_new_line(line, string_and_rpar_indices)) + else: + yield Err( + CannotTransform("All string groups have comments attached to them.") + ) + + def _transform_to_new_line( + self, line: Line, string_and_rpar_indices: List[int] + ) -> Line: + LL = line.leaves new_line = line.clone() new_line.comments = line.comments.copy() - try: - append_leaves(new_line, line, LL[: string_idx - 1]) - except BracketMatchError: - # HACK: I believe there is currently a bug somewhere in - # right_hand_split() that is causing brackets to not be tracked - # properly by a shared BracketTracker. - append_leaves(new_line, line, LL[: string_idx - 1], preformatted=True) - - string_leaf = Leaf(token.STRING, LL[string_idx].value) - LL[string_idx - 1].remove() - replace_child(LL[string_idx], string_leaf) - new_line.append(string_leaf) - - append_leaves( - new_line, line, LL[string_idx + 1 : rpar_idx] + LL[rpar_idx + 1 :] - ) - LL[rpar_idx].remove() + previous_idx = -1 + # We need to sort the indices, since string_idx and its matching + # rpar_idx may not come in order, e.g. in + # `("outer" % ("inner".join(items)))`, the "inner" string's + # string_idx is smaller than "outer" string's rpar_idx. + for idx in sorted(string_and_rpar_indices): + leaf = LL[idx] + lpar_or_rpar_idx = idx - 1 if leaf.type == token.STRING else idx + append_leaves(new_line, line, LL[previous_idx + 1 : lpar_or_rpar_idx]) + if leaf.type == token.STRING: + string_leaf = Leaf(token.STRING, LL[idx].value) + LL[lpar_or_rpar_idx].remove() # Remove lpar. + replace_child(LL[idx], string_leaf) + new_line.append(string_leaf) + else: + LL[lpar_or_rpar_idx].remove() # This is a rpar. + + previous_idx = idx + + # Append the leaves after the last idx: + append_leaves(new_line, line, LL[idx + 1 :]) - yield Ok(new_line) + return new_line class BaseStringSplitter(StringTransformer): @@ -884,7 +993,12 @@ class BaseStringSplitter(StringTransformer): if isinstance(match_result, Err): return match_result - string_idx = match_result.ok() + string_indices = match_result.ok() + assert len(string_indices) == 1, ( + f"{self.__class__.__name__} should only find one match at a time, found" + f" {len(string_indices)}" + ) + string_idx = string_indices[0] vresult = self._validate(line, string_idx) if isinstance(vresult, Err): return vresult @@ -1218,10 +1332,17 @@ class StringSplitter(BaseStringSplitter, CustomSplitMapMixin): if is_valid_index(idx): return TErr("This line does not end with a string.") - return Ok(string_idx) + return Ok([string_idx]) - def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + def do_transform( + self, line: Line, string_indices: List[int] + ) -> Iterator[TResult[Line]]: LL = line.leaves + assert len(string_indices) == 1, ( + f"{self.__class__.__name__} should only find one match at a time, found" + f" {len(string_indices)}" + ) + string_idx = string_indices[0] QUOTE = LL[string_idx].value[-1] @@ -1359,9 +1480,14 @@ class StringSplitter(BaseStringSplitter, CustomSplitMapMixin): # prefix, and the current custom split did NOT originally use a # prefix... if ( - next_value != self._normalize_f_string(next_value, prefix) - and use_custom_breakpoints + use_custom_breakpoints and not csplit.has_prefix + and ( + # `next_value == prefix + QUOTE` happens when the custom + # split is an empty string. + next_value == prefix + QUOTE + or next_value != self._normalize_f_string(next_value, prefix) + ) ): # Then `csplit.break_idx` will be off by one after removing # the 'f' prefix. @@ -1638,6 +1764,8 @@ class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin): * The line is a dictionary key assignment where some valid key is being assigned the value of some string. OR + * The line is an lambda expression and the value is a string. + OR * The line starts with an "atom" string that prefers to be wrapped in parens. It's preferred to be wrapped when the string is surrounded by commas (or is the first/last child). @@ -1683,7 +1811,7 @@ class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin): or self._else_match(LL) or self._assert_match(LL) or self._assign_match(LL) - or self._dict_match(LL) + or self._dict_or_lambda_match(LL) or self._prefer_paren_wrap_match(LL) ) @@ -1702,7 +1830,7 @@ class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin): " resultant line would still be over the specified line" " length and can't be split further by StringSplitter." ) - return Ok(string_idx) + return Ok([string_idx]) return TErr("This line does not contain any non-atomic strings.") @@ -1841,23 +1969,24 @@ class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin): return None @staticmethod - def _dict_match(LL: List[Leaf]) -> Optional[int]: + def _dict_or_lambda_match(LL: List[Leaf]) -> Optional[int]: """ Returns: string_idx such that @LL[string_idx] is equal to our target (i.e. matched) string, if this line matches the dictionary key assignment - statement requirements listed in the 'Requirements' section of this - classes' docstring. + statement or lambda expression requirements listed in the + 'Requirements' section of this classes' docstring. OR None, otherwise. """ - # If this line is apart of a dictionary key assignment... - if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]: + # If this line is a part of a dictionary key assignment or lambda expression... + parent_types = [parent_type(LL[0]), parent_type(LL[0].parent)] + if syms.dictsetmaker in parent_types or syms.lambdef in parent_types: is_valid_index = is_valid_index_factory(LL) for i, leaf in enumerate(LL): - # We MUST find a colon... - if leaf.type == token.COLON: + # We MUST find a colon, it can either be dict's or lambda's colon... + if leaf.type == token.COLON and i < len(LL) - 1: idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1 # That colon MUST be followed by a string... @@ -1878,8 +2007,15 @@ class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin): return None - def do_transform(self, line: Line, string_idx: int) -> Iterator[TResult[Line]]: + def do_transform( + self, line: Line, string_indices: List[int] + ) -> Iterator[TResult[Line]]: LL = line.leaves + assert len(string_indices) == 1, ( + f"{self.__class__.__name__} should only find one match at a time, found" + f" {len(string_indices)}" + ) + string_idx = string_indices[0] is_valid_index = is_valid_index_factory(LL) insert_str_child = insert_str_child_factory(LL[string_idx]) @@ -1951,6 +2087,25 @@ class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin): f" (left_leaves={left_leaves}, right_leaves={right_leaves})" ) old_rpar_leaf = right_leaves.pop() + elif right_leaves and right_leaves[-1].type == token.RPAR: + # Special case for lambda expressions as dict's value, e.g.: + # my_dict = { + # "key": lambda x: f"formatted: {x}, + # } + # After wrapping the dict's value with parentheses, the string is + # followed by a RPAR but its opening bracket is lambda's, not + # the string's: + # "key": (lambda x: f"formatted: {x}), + opening_bracket = right_leaves[-1].opening_bracket + if opening_bracket is not None and opening_bracket in left_leaves: + index = left_leaves.index(opening_bracket) + if ( + index > 0 + and index < len(left_leaves) - 1 + and left_leaves[index - 1].type == token.COLON + and left_leaves[index + 1].value == "lambda" + ): + right_leaves.pop() append_leaves(string_line, line, right_leaves)