git.madduck.net Git - etc/vim.git/blobdiff - blib2to3/pytree.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be individually arranged.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Fix for "# fmt: on" with decorators (#1325)
[etc/vim.git] / blib2to3 / pytree.py
index 6776491cfbf26aa69bdb46977cfeefdff91f4221..d1bcbe98a3381b987b59cd421e49c16f65014df4 100644 (file)
@@ -10,29 +10,56 @@ even the comments and whitespace between tokens.
 There's also a pattern matching implementation here.
 """
 
+# mypy: allow-untyped-defs
+
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Iterator,
+    List,
+    Optional,
+    Text,
+    Tuple,
+    TypeVar,
+    Union,
+    Set,
+    Iterable,
+    Sequence,
+)
+from blib2to3.pgen2.grammar import Grammar
+
 __author__ = "Guido van Rossum <guido@python.org>"
 
 import sys
 from io import StringIO
 
-HUGE = 0x7FFFFFFF  # maximum repeat count, default max
+HUGE: int = 0x7FFFFFFF  # maximum repeat count, default max
 
-_type_reprs = {}
+_type_reprs: Dict[int, Union[Text, int]] = {}
 
 
-def type_repr(type_num):
+def type_repr(type_num: int) -> Union[Text, int]:
     global _type_reprs
     if not _type_reprs:
         from .pygram import python_symbols
 
         # printing tokens is possible but not as useful
         # from .pgen2 import token // token.__dict__.items():
-        for name, val in python_symbols.__dict__.items():
+        for name in dir(python_symbols):
+            val = getattr(python_symbols, name)
             if type(val) == int:
                 _type_reprs[val] = name
     return _type_reprs.setdefault(type_num, type_num)
 
 
+_P = TypeVar("_P")
+
+NL = Union["Node", "Leaf"]
+Context = Tuple[Text, Tuple[int, int]]
+RawNode = Tuple[int, Optional[Text], Optional[Context], Optional[List[NL]]]
+
+
 class Base(object):
 
     """
@@ -45,18 +72,18 @@ class Base(object):
     """
 
     # Default values for instance variables
-    type = None  # int: token number (< 256) or symbol number (>= 256)
-    parent = None  # Parent node pointer, or None
-    children = ()  # Tuple of subnodes
-    was_changed = False
-    was_checked = False
+    type: int  # int: token number (< 256) or symbol number (>= 256)
+    parent: Optional["Node"] = None  # Parent node pointer, or None
+    children: List[NL]  # List of subnodes
+    was_changed: bool = False
+    was_checked: bool = False
 
     def __new__(cls, *args, **kwds):
         """Constructor that prevents Base from being instantiated."""
         assert cls is not Base, "Cannot instantiate Base"
         return object.__new__(cls)
 
-    def __eq__(self, other):
+    def __eq__(self, other: Any) -> bool:
         """
         Compare two nodes for equality.
 
@@ -66,9 +93,13 @@ class Base(object):
             return NotImplemented
         return self._eq(other)
 
-    __hash__ = None  # For Py3 compatibility.
+    __hash__ = None  # type: Any  # For Py3 compatibility.
+
+    @property
+    def prefix(self) -> Text:
+        raise NotImplementedError
 
-    def _eq(self, other):
+    def _eq(self: _P, other: _P) -> bool:
         """
         Compare two nodes for equality.
 
@@ -79,7 +110,7 @@ class Base(object):
         """
         raise NotImplementedError
 
-    def clone(self):
+    def clone(self: _P) -> _P:
         """
         Return a cloned (deep) copy of self.
 
@@ -87,7 +118,7 @@ class Base(object):
         """
         raise NotImplementedError
 
-    def post_order(self):
+    def post_order(self) -> Iterator[NL]:
         """
         Return a post-order iterator for the tree.
 
@@ -95,7 +126,7 @@ class Base(object):
         """
         raise NotImplementedError
 
-    def pre_order(self):
+    def pre_order(self) -> Iterator[NL]:
         """
         Return a pre-order iterator for the tree.
 
@@ -103,7 +134,7 @@ class Base(object):
         """
         raise NotImplementedError
 
-    def replace(self, new):
+    def replace(self, new: Union[NL, List[NL]]) -> None:
         """Replace this node with a new one in the parent."""
         assert self.parent is not None, str(self)
         assert new is not None
@@ -127,23 +158,23 @@ class Base(object):
             x.parent = self.parent
         self.parent = None
 
-    def get_lineno(self):
+    def get_lineno(self) -> Optional[int]:
         """Return the line number which generated the invocant node."""
         node = self
         while not isinstance(node, Leaf):
             if not node.children:
-                return
+                return None
             node = node.children[0]
         return node.lineno
 
-    def changed(self):
+    def changed(self) -> None:
         if self.was_changed:
             return
         if self.parent:
             self.parent.changed()
         self.was_changed = True
 
-    def remove(self):
+    def remove(self) -> Optional[int]:
         """
         Remove the node from the tree. Returns the position of the node in its
         parent's children before it was removed.
@@ -156,9 +187,10 @@ class Base(object):
                     self.parent.invalidate_sibling_maps()
                     self.parent = None
                     return i
+        return None
 
     @property
-    def next_sibling(self):
+    def next_sibling(self) -> Optional[NL]:
         """
         The node immediately following the invocant in their parent's children
         list. If the invocant does not have a next sibling, it is None
@@ -168,10 +200,11 @@ class Base(object):
 
         if self.parent.next_sibling_map is None:
             self.parent.update_sibling_maps()
+        assert self.parent.next_sibling_map is not None
         return self.parent.next_sibling_map[id(self)]
 
     @property
-    def prev_sibling(self):
+    def prev_sibling(self) -> Optional[NL]:
         """
         The node immediately preceding the invocant in their parent's children
         list. If the invocant does not have a previous sibling, it is None.
@@ -181,18 +214,19 @@ class Base(object):
 
         if self.parent.prev_sibling_map is None:
             self.parent.update_sibling_maps()
+        assert self.parent.prev_sibling_map is not None
         return self.parent.prev_sibling_map[id(self)]
 
-    def leaves(self):
+    def leaves(self) -> Iterator["Leaf"]:
         for child in self.children:
             yield from child.leaves()
 
-    def depth(self):
+    def depth(self) -> int:
         if self.parent is None:
             return 0
         return 1 + self.parent.depth()
 
-    def get_suffix(self):
+    def get_suffix(self) -> Text:
         """
         Return the string immediately following the invocant node. This is
         effectively equivalent to node.next_sibling.prefix
@@ -200,19 +234,25 @@ class Base(object):
         next_sib = self.next_sibling
         if next_sib is None:
             return ""
-        return next_sib.prefix
-
-    if sys.version_info < (3, 0):
-
-        def __str__(self):
-            return str(self).encode("ascii")
+        prefix = next_sib.prefix
+        return prefix
 
 
 class Node(Base):
 
     """Concrete implementation for interior nodes."""
 
-    def __init__(self, type, children, context=None, prefix=None, fixers_applied=None):
+    fixers_applied: Optional[List[Any]]
+    used_names: Optional[Set[Text]]
+
+    def __init__(
+        self,
+        type: int,
+        children: List[NL],
+        context: Optional[Any] = None,
+        prefix: Optional[Text] = None,
+        fixers_applied: Optional[List[Any]] = None,
+    ) -> None:
         """
         Initializer.
 
@@ -235,15 +275,16 @@ class Node(Base):
         else:
             self.fixers_applied = None
 
-    def __repr__(self):
+    def __repr__(self) -> Text:
         """Return a canonical string representation."""
+        assert self.type is not None
         return "%s(%s, %r)" % (
             self.__class__.__name__,
             type_repr(self.type),
             self.children,
         )
 
-    def __unicode__(self):
+    def __str__(self) -> Text:
         """
         Return a pretty string representation.
 
@@ -251,14 +292,12 @@ class Node(Base):
         """
         return "".join(map(str, self.children))
 
-    if sys.version_info > (3, 0):
-        __str__ = __unicode__
-
-    def _eq(self, other):
+    def _eq(self, other) -> bool:
         """Compare two nodes for equality."""
         return (self.type, self.children) == (other.type, other.children)
 
-    def clone(self):
+    def clone(self) -> "Node":
+        assert self.type is not None
         """Return a cloned (deep) copy of self."""
         return Node(
             self.type,
@@ -266,20 +305,20 @@ class Node(Base):
             fixers_applied=self.fixers_applied,
         )
 
-    def post_order(self):
+    def post_order(self) -> Iterator[NL]:
         """Return a post-order iterator for the tree."""
         for child in self.children:
             yield from child.post_order()
         yield self
 
-    def pre_order(self):
+    def pre_order(self) -> Iterator[NL]:
         """Return a pre-order iterator for the tree."""
         yield self
         for child in self.children:
             yield from child.pre_order()
 
     @property
-    def prefix(self):
+    def prefix(self) -> Text:
         """
         The whitespace and comments preceding this node in the input.
         """
@@ -288,11 +327,11 @@ class Node(Base):
         return self.children[0].prefix
 
     @prefix.setter
-    def prefix(self, prefix):
+    def prefix(self, prefix) -> None:
         if self.children:
             self.children[0].prefix = prefix
 
-    def set_child(self, i, child):
+    def set_child(self, i: int, child: NL) -> None:
         """
         Equivalent to 'node.children[i] = child'. This method also sets the
         child's parent attribute appropriately.
@@ -303,7 +342,7 @@ class Node(Base):
         self.changed()
         self.invalidate_sibling_maps()
 
-    def insert_child(self, i, child):
+    def insert_child(self, i: int, child: NL) -> None:
         """
         Equivalent to 'node.children.insert(i, child)'. This method also sets
         the child's parent attribute appropriately.
@@ -313,7 +352,7 @@ class Node(Base):
         self.changed()
         self.invalidate_sibling_maps()
 
-    def append_child(self, child):
+    def append_child(self, child: NL) -> None:
         """
         Equivalent to 'node.children.append(child)'. This method also sets the
         child's parent attribute appropriately.
@@ -323,14 +362,16 @@ class Node(Base):
         self.changed()
         self.invalidate_sibling_maps()
 
-    def invalidate_sibling_maps(self):
-        self.prev_sibling_map = None
-        self.next_sibling_map = None
+    def invalidate_sibling_maps(self) -> None:
+        self.prev_sibling_map: Optional[Dict[int, Optional[NL]]] = None
+        self.next_sibling_map: Optional[Dict[int, Optional[NL]]] = None
 
-    def update_sibling_maps(self):
-        self.prev_sibling_map = _prev = {}
-        self.next_sibling_map = _next = {}
-        previous = None
+    def update_sibling_maps(self) -> None:
+        _prev: Dict[int, Optional[NL]] = {}
+        _next: Dict[int, Optional[NL]] = {}
+        self.prev_sibling_map = _prev
+        self.next_sibling_map = _next
+        previous: Optional[NL] = None
         for current in self.children:
             _prev[id(current)] = previous
             _next[id(previous)] = current
@@ -343,17 +384,30 @@ class Leaf(Base):
     """Concrete implementation for leaf nodes."""
 
     # Default values for instance variables
+    value: Text
+    fixers_applied: List[Any]
+    bracket_depth: int
+    opening_bracket: "Leaf"
+    used_names: Optional[Set[Text]]
     _prefix = ""  # Whitespace and comments preceding this token in the input
-    lineno = 0  # Line where this token starts in the input
-    column = 0  # Column where this token starts in the input
-
-    def __init__(self, type, value, context=None, prefix=None, fixers_applied=[]):
+    lineno: int = 0  # Line where this token starts in the input
+    column: int = 0  # Column where this token starts in the input
+
+    def __init__(
+        self,
+        type: int,
+        value: Text,
+        context: Optional[Context] = None,
+        prefix: Optional[Text] = None,
+        fixers_applied: List[Any] = [],
+    ) -> None:
         """
         Initializer.
 
         Takes a type constant (a token number < 256), a string value, and an
         optional context keyword argument.
         """
+
         assert 0 <= type < 256, type
         if context is not None:
             self._prefix, (self.lineno, self.column) = context
@@ -361,19 +415,21 @@ class Leaf(Base):
         self.value = value
         if prefix is not None:
             self._prefix = prefix
-        self.fixers_applied = fixers_applied[:]
+        self.fixers_applied: Optional[List[Any]] = fixers_applied[:]
+        self.children = []
 
-    def __repr__(self):
+    def __repr__(self) -> str:
         """Return a canonical string representation."""
         from .pgen2.token import tok_name
 
+        assert self.type is not None
         return "%s(%s, %r)" % (
             self.__class__.__name__,
             tok_name.get(self.type, self.type),
             self.value,
         )
 
-    def __unicode__(self):
+    def __str__(self) -> Text:
         """
         Return a pretty string representation.
 
@@ -381,14 +437,12 @@ class Leaf(Base):
         """
         return self.prefix + str(self.value)
 
-    if sys.version_info > (3, 0):
-        __str__ = __unicode__
-
-    def _eq(self, other):
+    def _eq(self, other) -> bool:
         """Compare two nodes for equality."""
         return (self.type, self.value) == (other.type, other.value)
 
-    def clone(self):
+    def clone(self) -> "Leaf":
+        assert self.type is not None
         """Return a cloned (deep) copy of self."""
         return Leaf(
             self.type,
@@ -397,31 +451,31 @@ class Leaf(Base):
             fixers_applied=self.fixers_applied,
         )
 
-    def leaves(self):
+    def leaves(self) -> Iterator["Leaf"]:
         yield self
 
-    def post_order(self):
+    def post_order(self) -> Iterator["Leaf"]:
         """Return a post-order iterator for the tree."""
         yield self
 
-    def pre_order(self):
+    def pre_order(self) -> Iterator["Leaf"]:
         """Return a pre-order iterator for the tree."""
         yield self
 
     @property
-    def prefix(self):
+    def prefix(self) -> Text:
         """
         The whitespace and comments preceding this token in the input.
         """
         return self._prefix
 
     @prefix.setter
-    def prefix(self, prefix):
+    def prefix(self, prefix) -> None:
         self.changed()
         self._prefix = prefix
 
 
-def convert(gr, raw_node):
+def convert(gr: Grammar, raw_node: RawNode) -> NL:
     """
     Convert raw node information to a Node or Leaf instance.
 
@@ -433,11 +487,15 @@ def convert(gr, raw_node):
     if children or type in gr.number2symbol:
         # If there's exactly one child, return that child instead of
         # creating a new node.
+        assert children is not None
         if len(children) == 1:
             return children[0]
         return Node(type, children, context=context)
     else:
-        return Leaf(type, value, context=context)
+        return Leaf(type, value or "", context=context)
+
+
+_Results = Dict[Text, NL]
 
 
 class BasePattern(object):
@@ -457,22 +515,27 @@ class BasePattern(object):
     """
 
     # Defaults for instance variables
+    type: Optional[int]
     type = None  # Node type (token if < 256, symbol if >= 256)
-    content = None  # Optional content matching pattern
-    name = None  # Optional name used to store match in results dict
+    content: Any = None  # Optional content matching pattern
+    name: Optional[Text] = None  # Optional name used to store match in results dict
 
     def __new__(cls, *args, **kwds):
         """Constructor that prevents BasePattern from being instantiated."""
         assert cls is not BasePattern, "Cannot instantiate BasePattern"
         return object.__new__(cls)
 
-    def __repr__(self):
+    def __repr__(self) -> Text:
+        assert self.type is not None
         args = [type_repr(self.type), self.content, self.name]
         while args and args[-1] is None:
             del args[-1]
         return "%s(%s)" % (self.__class__.__name__, ", ".join(map(repr, args)))
 
-    def optimize(self):
+    def _submatch(self, node, results=None) -> bool:
+        raise NotImplementedError
+
+    def optimize(self) -> "BasePattern":
         """
         A subclass can define this as a hook for optimizations.
 
@@ -480,7 +543,7 @@ class BasePattern(object):
         """
         return self
 
-    def match(self, node, results=None):
+    def match(self, node: NL, results: Optional[_Results] = None) -> bool:
         """
         Does this pattern exactly match a node?
 
@@ -494,18 +557,19 @@ class BasePattern(object):
         if self.type is not None and node.type != self.type:
             return False
         if self.content is not None:
-            r = None
+            r: Optional[_Results] = None
             if results is not None:
                 r = {}
             if not self._submatch(node, r):
                 return False
             if r:
+                assert results is not None
                 results.update(r)
         if results is not None and self.name:
             results[self.name] = node
         return True
 
-    def match_seq(self, nodes, results=None):
+    def match_seq(self, nodes: List[NL], results: Optional[_Results] = None) -> bool:
         """
         Does this pattern exactly match a sequence of nodes?
 
@@ -515,19 +579,24 @@ class BasePattern(object):
             return False
         return self.match(nodes[0], results)
 
-    def generate_matches(self, nodes):
+    def generate_matches(self, nodes: List[NL]) -> Iterator[Tuple[int, _Results]]:
         """
         Generator yielding all matches for this pattern.
 
         Default implementation for non-wildcard patterns.
         """
-        r = {}
+        r: _Results = {}
         if nodes and self.match(nodes[0], r):
             yield 1, r
 
 
 class LeafPattern(BasePattern):
-    def __init__(self, type=None, content=None, name=None):
+    def __init__(
+        self,
+        type: Optional[int] = None,
+        content: Optional[Text] = None,
+        name: Optional[Text] = None,
+    ) -> None:
         """
         Initializer.  Takes optional type, content, and name.
 
@@ -547,7 +616,7 @@ class LeafPattern(BasePattern):
         self.content = content
         self.name = name
 
-    def match(self, node, results=None):
+    def match(self, node: NL, results=None):
         """Override match() to insist on a leaf node."""
         if not isinstance(node, Leaf):
             return False
@@ -571,9 +640,14 @@ class LeafPattern(BasePattern):
 
 class NodePattern(BasePattern):
 
-    wildcards = False
+    wildcards: bool = False
 
-    def __init__(self, type=None, content=None, name=None):
+    def __init__(
+        self,
+        type: Optional[int] = None,
+        content: Optional[Iterable[Text]] = None,
+        name: Optional[Text] = None,
+    ) -> None:
         """
         Initializer.  Takes optional type, content, and name.
 
@@ -593,16 +667,16 @@ class NodePattern(BasePattern):
             assert type >= 256, type
         if content is not None:
             assert not isinstance(content, str), repr(content)
-            content = list(content)
-            for i, item in enumerate(content):
+            newcontent = list(content)
+            for i, item in enumerate(newcontent):
                 assert isinstance(item, BasePattern), (i, item)
                 if isinstance(item, WildcardPattern):
                     self.wildcards = True
         self.type = type
-        self.content = content
+        self.content = newcontent
         self.name = name
 
-    def _submatch(self, node, results=None):
+    def _submatch(self, node, results=None) -> bool:
         """
         Match the pattern's content to the node's children.
 
@@ -644,7 +718,16 @@ class WildcardPattern(BasePattern):
     except it always uses non-greedy matching.
     """
 
-    def __init__(self, content=None, min=0, max=HUGE, name=None):
+    min: int
+    max: int
+
+    def __init__(
+        self,
+        content: Optional[Text] = None,
+        min: int = 0,
+        max: int = HUGE,
+        name: Optional[Text] = None,
+    ) -> None:
         """
         Initializer.
 
@@ -669,17 +752,20 @@ class WildcardPattern(BasePattern):
         """
         assert 0 <= min <= max <= HUGE, (min, max)
         if content is not None:
-            content = tuple(map(tuple, content))  # Protect against alterations
+            f = lambda s: tuple(s)
+            wrapped_content = tuple(map(f, content))  # Protect against alterations
             # Check sanity of alternatives
-            assert len(content), repr(content)  # Can't have zero alternatives
-            for alt in content:
+            assert len(wrapped_content), repr(
+                wrapped_content
+            )  # Can't have zero alternatives
+            for alt in wrapped_content:
                 assert len(alt), repr(alt)  # Can have empty alternatives
-        self.content = content
+        self.content = wrapped_content
         self.min = min
         self.max = max
         self.name = name
 
-    def optimize(self):
+    def optimize(self) -> Any:
         """Optimize certain stacked wildcard patterns."""
         subpattern = None
         if (
@@ -707,11 +793,11 @@ class WildcardPattern(BasePattern):
             )
         return self
 
-    def match(self, node, results=None):
+    def match(self, node, results=None) -> bool:
         """Does this pattern exactly match a node?"""
         return self.match_seq([node], results)
 
-    def match_seq(self, nodes, results=None):
+    def match_seq(self, nodes, results=None) -> bool:
         """Does this pattern exactly match a sequence of nodes?"""
         for c, r in self.generate_matches(nodes):
             if c == len(nodes):
@@ -722,7 +808,7 @@ class WildcardPattern(BasePattern):
                 return True
         return False
 
-    def generate_matches(self, nodes):
+    def generate_matches(self, nodes) -> Iterator[Tuple[int, _Results]]:
         """
         Generator yielding matches for a sequence of nodes.
 
@@ -767,7 +853,7 @@ class WildcardPattern(BasePattern):
                 if hasattr(sys, "getrefcount"):
                     sys.stderr = save_stderr
 
-    def _iterative_matches(self, nodes):
+    def _iterative_matches(self, nodes) -> Iterator[Tuple[int, _Results]]:
         """Helper to iteratively yield the matches."""
         nodelen = len(nodes)
         if 0 >= self.min:
@@ -796,10 +882,10 @@ class WildcardPattern(BasePattern):
                                 new_results.append((c0 + c1, r))
             results = new_results
 
-    def _bare_name_matches(self, nodes):
+    def _bare_name_matches(self, nodes) -> Tuple[int, _Results]:
         """Special optimized matcher for bare_name."""
         count = 0
-        r = {}
+        r = {}  # type: _Results
         done = False
         max = len(nodes)
         while not done and count < max:
@@ -809,10 +895,11 @@ class WildcardPattern(BasePattern):
                     count += 1
                     done = False
                     break
+        assert self.name is not None
         r[self.name] = nodes[:count]
         return count, r
 
-    def _recursive_matches(self, nodes, count):
+    def _recursive_matches(self, nodes, count) -> Iterator[Tuple[int, _Results]]:
         """Helper to recursively yield the matches."""
         assert self.content is not None
         if count >= self.min:
@@ -828,7 +915,7 @@ class WildcardPattern(BasePattern):
 
 
 class NegatedPattern(BasePattern):
-    def __init__(self, content=None):
+    def __init__(self, content: Optional[Any] = None) -> None:
         """
         Initializer.
 
@@ -841,15 +928,15 @@ class NegatedPattern(BasePattern):
             assert isinstance(content, BasePattern), repr(content)
         self.content = content
 
-    def match(self, node):
+    def match(self, node, results=None) -> bool:
         # We never match a node in its entirety
         return False
 
-    def match_seq(self, nodes):
+    def match_seq(self, nodes, results=None) -> bool:
         # We only match an empty sequence of nodes in its entirety
         return len(nodes) == 0
 
-    def generate_matches(self, nodes):
+    def generate_matches(self, nodes) -> Iterator[Tuple[int, _Results]]:
         if self.content is None:
             # Return a match if there is an empty sequence
             if len(nodes) == 0:
@@ -861,7 +948,9 @@ class NegatedPattern(BasePattern):
             yield 0, {}
 
 
-def generate_matches(patterns, nodes):
+def generate_matches(
+    patterns: List[BasePattern], nodes: List[NL]
+) -> Iterator[Tuple[int, _Results]]:
     """
     Generator yielding matches for a sequence of patterns and nodes.
 
@@ -887,3 +976,6 @@ def generate_matches(patterns, nodes):
                     r.update(r0)
                     r.update(r1)
                     yield c0 + c1, r
+
+
+_Convert = Callable[[Grammar, RawNode], Any]