git.madduck.net Git - etc/vim.git/blobdiff - src/black/trans.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Copy over comments when hugging power ops (#2874)
[etc/vim.git] / src / black / trans.py
index 6aca3a8733f1f2f74624e6ad7aaba3e9e70313c6..28d9250adc1fe78d7e788b2d6fd382d00d42073b 100644 (file)
@@ -4,7 +4,7 @@ String transformers that can split and merge strings.
 from abc import ABC, abstractmethod
 from collections import defaultdict
 from dataclasses import dataclass
-import regex as re  # We need recursive patterns here (?R)
+import re
 from typing import (
     Any,
     Callable,
@@ -24,9 +24,9 @@ from typing import (
 import sys
 
 if sys.version_info < (3, 8):
-    from typing_extensions import Final
+    from typing_extensions import Literal, Final
 else:
-    from typing import Final
+    from typing import Literal, Final
 
 from mypy_extensions import trait
 
@@ -71,6 +71,84 @@ def TErr(err_msg: str) -> Err[CannotTransform]:
     return Err(cant_transform)
 
 
+def hug_power_op(line: Line, features: Collection[Feature]) -> Iterator[Line]:
+    """Remove the spaces around `**` when both of its operands are simple."""
+
+    # Fast path: raise CannotTransform before cloning anything if there is no `**`.
+    for leaf in line.leaves:
+        if leaf.type == token.DOUBLESTAR:
+            break
+    else:
+        raise CannotTransform("No doublestar token was found in the line.")
+
+    def is_simple_lookup(index: int, step: Literal[1, -1]) -> bool:
+        # Scan from `index` by `step` (-1 leftwards over a base, 1 rightwards over an
+        # exponent). A bracket attached to the lookup means a call or subscript, which
+        # isn't "simple"; only a NAME or a dotted lookup (eg. NAME.NAME) qualifies.
+        if step == -1:
+            disallowed = {token.RPAR, token.RSQB}
+        else:
+            disallowed = {token.LPAR, token.LSQB}
+
+        while 0 <= index < len(line.leaves):
+            current = line.leaves[index]
+            if current.type in disallowed:
+                return False
+            if current.type not in {token.NAME, token.DOT} or current.value == "for":
+                # If the current token isn't disallowed, we'll assume this is simple as
+                # only the disallowed tokens are semantically attached to this lookup
+                # expression we're checking. Also, stop early if we hit the 'for' bit
+                # of a comprehension.
+                return True
+
+            index += step
+
+        return True
+
+    def is_simple_operand(index: int, kind: Literal["base", "exponent"]) -> bool:
+        # An operand is "simple" if it is a NAME or a NUMBER forming a simple lookup
+        # (see above), optionally preceded by a unary +, - or ~ operator.
+        start = line.leaves[index]
+        if start.type in {token.NAME, token.NUMBER}:
+            return is_simple_lookup(index, step=(1 if kind == "exponent" else -1))
+
+        if start.type in {token.PLUS, token.MINUS, token.TILDE}:
+            if line.leaves[index + 1].type in {token.NAME, token.NUMBER}:
+                # Only exponents (eg. `x**-y`) get here: for a base the next leaf is
+                # the `**` itself. Scan rightwards from the operand after the unary
+                # operator, so step is always 1.
+                return is_simple_lookup(index + 1, step=1)
+
+        return False
+
+    new_line = line.clone()
+    should_hug = False
+    for idx, leaf in enumerate(line.leaves):
+        new_leaf = leaf.clone()
+        if should_hug:
+            new_leaf.prefix = ""
+            should_hug = False
+
+        should_hug = (
+            (0 < idx < len(line.leaves) - 1)
+            and leaf.type == token.DOUBLESTAR
+            and is_simple_operand(idx - 1, kind="base")
+            and line.leaves[idx - 1].value != "lambda"
+            and is_simple_operand(idx + 1, kind="exponent")
+        )
+        if should_hug:
+            new_leaf.prefix = ""
+
+        # We have to be careful to make a new line properly:
+        # - bracket related metadata must be maintained (handled by Line.append)
+        # - comments need to be copied over, updating the leaf IDs they're attached to
+        new_line.append(new_leaf, preformatted=True)
+        for comment_leaf in line.comments_after(leaf):
+            new_line.append(comment_leaf, preformatted=True)
+
+    yield new_line
+
+
 class StringTransformer(ABC):
     """
     An implementation of the Transformer protocol that relies on its
@@ -283,7 +361,7 @@ class StringMerger(StringTransformer, CustomSplitMapMixin):
 
         is_valid_index = is_valid_index_factory(LL)
 
-        for (i, leaf) in enumerate(LL):
+        for i, leaf in enumerate(LL):
             if (
                 leaf.type == token.STRING
                 and is_valid_index(i + 1)
@@ -453,7 +531,7 @@ class StringMerger(StringTransformer, CustomSplitMapMixin):
             # with 'f'...
             if "f" in prefix and "f" not in next_prefix:
                 # Then we must escape any braces contained in this substring.
-                SS = re.subf(r"(\{|\})", "{1}{1}", SS)
+                SS = re.sub(r"(\{|\})", r"\1\1", SS)
 
             NSS = make_naked(SS, next_prefix)
 
@@ -488,7 +566,7 @@ class StringMerger(StringTransformer, CustomSplitMapMixin):
 
         # Build the final line ('new_line') that this method will later return.
         new_line = line.clone()
-        for (i, leaf) in enumerate(LL):
+        for i, leaf in enumerate(LL):
             if i == string_idx:
                 new_line.append(string_leaf)
 
@@ -609,7 +687,7 @@ class StringParenStripper(StringTransformer):
 
         is_valid_index = is_valid_index_factory(LL)
 
-        for (idx, leaf) in enumerate(LL):
+        for idx, leaf in enumerate(LL):
             # Should be a string...
             if leaf.type != token.STRING:
                 continue
@@ -1631,7 +1709,7 @@ class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin):
         if parent_type(LL[0]) == syms.assert_stmt and LL[0].value == "assert":
             is_valid_index = is_valid_index_factory(LL)
 
-            for (i, leaf) in enumerate(LL):
+            for i, leaf in enumerate(LL):
                 # We MUST find a comma...
                 if leaf.type == token.COMMA:
                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
@@ -1669,7 +1747,7 @@ class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin):
         ):
             is_valid_index = is_valid_index_factory(LL)
 
-            for (i, leaf) in enumerate(LL):
+            for i, leaf in enumerate(LL):
                 # We MUST find either an '=' or '+=' symbol...
                 if leaf.type in [token.EQUAL, token.PLUSEQUAL]:
                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1
@@ -1712,7 +1790,7 @@ class StringParenWrapper(BaseStringSplitter, CustomSplitMapMixin):
         if syms.dictsetmaker in [parent_type(LL[0]), parent_type(LL[0].parent)]:
             is_valid_index = is_valid_index_factory(LL)
 
-            for (i, leaf) in enumerate(LL):
+            for i, leaf in enumerate(LL):
                 # We MUST find a colon...
                 if leaf.type == token.COLON:
                     idx = i + 2 if is_empty_par(LL[i + 1]) else i + 1