git.madduck.net Git - etc/vim.git/blobdiff - blib2to3/pytree.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes to logical commits before using git-format-patch and git-send-email to patches@git.madduck.net. If you'd read over the Git project's submission guidelines and adhered to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Blacken .py files in blib2to3 (#1011)
[etc/vim.git] / blib2to3 / pytree.py
index 693366f7b2e4fb3bba6ad237aefa8bfca7904dc3..6776491cfbf26aa69bdb46977cfeefdff91f4221 100644 (file)
@@ -18,16 +18,21 @@ from io import StringIO
 HUGE = 0x7FFFFFFF  # maximum repeat count, default max
 
 _type_reprs = {}
+
+
 def type_repr(type_num):
     global _type_reprs
     if not _type_reprs:
         from .pygram import python_symbols
+
         # printing tokens is possible but not as useful
         # from .pgen2 import token // token.__dict__.items():
         for name, val in python_symbols.__dict__.items():
-            if type(val) == int: _type_reprs[val] = name
+            if type(val) == int:
+                _type_reprs[val] = name
     return _type_reprs.setdefault(type_num, type_num)
 
+
 class Base(object):
 
     """
@@ -40,7 +45,7 @@ class Base(object):
     """
 
     # Default values for instance variables
-    type = None    # int: token number (< 256) or symbol number (>= 256)
+    type = None  # int: token number (< 256) or symbol number (>= 256)
     parent = None  # Parent node pointer, or None
     children = ()  # Tuple of subnodes
     was_changed = False
@@ -61,7 +66,7 @@ class Base(object):
             return NotImplemented
         return self._eq(other)
 
-    __hash__ = None # For Py3 compatibility.
+    __hash__ = None  # For Py3 compatibility.
 
     def _eq(self, other):
         """
@@ -115,8 +120,9 @@ class Base(object):
             else:
                 l_children.append(ch)
         assert found, (self.children, self, new)
-        self.parent.changed()
         self.parent.children = l_children
+        self.parent.changed()
+        self.parent.invalidate_sibling_maps()
         for x in new:
             x.parent = self.parent
         self.parent = None
@@ -131,6 +137,8 @@ class Base(object):
         return node.lineno
 
     def changed(self):
+        if self.was_changed:
+            return
         if self.parent:
             self.parent.changed()
         self.was_changed = True
@@ -143,8 +151,9 @@ class Base(object):
         if self.parent:
             for i, node in enumerate(self.parent.children):
                 if node is self:
-                    self.parent.changed()
                     del self.parent.children[i]
+                    self.parent.changed()
+                    self.parent.invalidate_sibling_maps()
                     self.parent = None
                     return i
 
@@ -157,13 +166,9 @@ class Base(object):
         if self.parent is None:
             return None
 
-        # Can't use index(); we need to test by identity
-        for i, child in enumerate(self.parent.children):
-            if child is self:
-                try:
-                    return self.parent.children[i+1]
-                except IndexError:
-                    return None
+        if self.parent.next_sibling_map is None:
+            self.parent.update_sibling_maps()
+        return self.parent.next_sibling_map[id(self)]
 
     @property
     def prev_sibling(self):
@@ -174,12 +179,9 @@ class Base(object):
         if self.parent is None:
             return None
 
-        # Can't use index(); we need to test by identity
-        for i, child in enumerate(self.parent.children):
-            if child is self:
-                if i == 0:
-                    return None
-                return self.parent.children[i-1]
+        if self.parent.prev_sibling_map is None:
+            self.parent.update_sibling_maps()
+        return self.parent.prev_sibling_map[id(self)]
 
     def leaves(self):
         for child in self.children:
@@ -201,17 +203,16 @@ class Base(object):
         return next_sib.prefix
 
     if sys.version_info < (3, 0):
+
         def __str__(self):
             return str(self).encode("ascii")
 
+
 class Node(Base):
 
     """Concrete implementation for interior nodes."""
 
-    def __init__(self,type, children,
-                 context=None,
-                 prefix=None,
-                 fixers_applied=None):
+    def __init__(self, type, children, context=None, prefix=None, fixers_applied=None):
         """
         Initializer.
 
@@ -226,6 +227,7 @@ class Node(Base):
         for ch in self.children:
             assert ch.parent is None, repr(ch)
             ch.parent = self
+        self.invalidate_sibling_maps()
         if prefix is not None:
             self.prefix = prefix
         if fixers_applied:
@@ -235,9 +237,11 @@ class Node(Base):
 
     def __repr__(self):
         """Return a canonical string representation."""
-        return "%s(%s, %r)" % (self.__class__.__name__,
-                               type_repr(self.type),
-                               self.children)
+        return "%s(%s, %r)" % (
+            self.__class__.__name__,
+            type_repr(self.type),
+            self.children,
+        )
 
     def __unicode__(self):
         """
@@ -256,8 +260,11 @@ class Node(Base):
 
     def clone(self):
         """Return a cloned (deep) copy of self."""
-        return Node(self.type, [ch.clone() for ch in self.children],
-                    fixers_applied=self.fixers_applied)
+        return Node(
+            self.type,
+            [ch.clone() for ch in self.children],
+            fixers_applied=self.fixers_applied,
+        )
 
     def post_order(self):
         """Return a post-order iterator for the tree."""
@@ -294,6 +301,7 @@ class Node(Base):
         self.children[i].parent = None
         self.children[i] = child
         self.changed()
+        self.invalidate_sibling_maps()
 
     def insert_child(self, i, child):
         """
@@ -303,6 +311,7 @@ class Node(Base):
         child.parent = self
         self.children.insert(i, child)
         self.changed()
+        self.invalidate_sibling_maps()
 
     def append_child(self, child):
         """
@@ -312,6 +321,21 @@ class Node(Base):
         child.parent = self
         self.children.append(child)
         self.changed()
+        self.invalidate_sibling_maps()
+
+    def invalidate_sibling_maps(self):
+        self.prev_sibling_map = None
+        self.next_sibling_map = None
+
+    def update_sibling_maps(self):
+        self.prev_sibling_map = _prev = {}
+        self.next_sibling_map = _next = {}
+        previous = None
+        for current in self.children:
+            _prev[id(current)] = previous
+            _next[id(previous)] = current
+            previous = current
+        _next[id(current)] = None
 
 
 class Leaf(Base):
@@ -320,13 +344,10 @@ class Leaf(Base):
 
     # Default values for instance variables
     _prefix = ""  # Whitespace and comments preceding this token in the input
-    lineno = 0    # Line where this token starts in the input
-    column = 0    # Column where this token tarts in the input
+    lineno = 0  # Line where this token starts in the input
+    column = 0  # Column where this token starts in the input
 
-    def __init__(self, type, value,
-                 context=None,
-                 prefix=None,
-                 fixers_applied=[]):
+    def __init__(self, type, value, context=None, prefix=None, fixers_applied=[]):
         """
         Initializer.
 
@@ -345,9 +366,12 @@ class Leaf(Base):
     def __repr__(self):
         """Return a canonical string representation."""
         from .pgen2.token import tok_name
-        return "%s(%s, %r)" % (self.__class__.__name__,
-                               tok_name.get(self.type, self.type),
-                               self.value)
+
+        return "%s(%s, %r)" % (
+            self.__class__.__name__,
+            tok_name.get(self.type, self.type),
+            self.value,
+        )
 
     def __unicode__(self):
         """
@@ -366,9 +390,12 @@ class Leaf(Base):
 
     def clone(self):
         """Return a cloned (deep) copy of self."""
-        return Leaf(self.type, self.value,
-                    (self.prefix, (self.lineno, self.column)),
-                    fixers_applied=self.fixers_applied)
+        return Leaf(
+            self.type,
+            self.value,
+            (self.prefix, (self.lineno, self.column)),
+            fixers_applied=self.fixers_applied,
+        )
 
     def leaves(self):
         yield self
@@ -393,6 +420,7 @@ class Leaf(Base):
         self.changed()
         self._prefix = prefix
 
+
 def convert(gr, raw_node):
     """
     Convert raw node information to a Node or Leaf instance.
@@ -429,9 +457,9 @@ class BasePattern(object):
     """
 
     # Defaults for instance variables
-    type = None     # Node type (token if < 256, symbol if >= 256)
+    type = None  # Node type (token if < 256, symbol if >= 256)
     content = None  # Optional content matching pattern
-    name = None     # Optional name used to store match in results dict
+    name = None  # Optional name used to store match in results dict
 
     def __new__(cls, *args, **kwds):
         """Constructor that prevents BasePattern from being instantiated."""
@@ -499,7 +527,6 @@ class BasePattern(object):
 
 
 class LeafPattern(BasePattern):
-
     def __init__(self, type=None, content=None, name=None):
         """
         Initializer.  Takes optional type, content, and name.
@@ -646,7 +673,7 @@ class WildcardPattern(BasePattern):
             # Check sanity of alternatives
             assert len(content), repr(content)  # Can't have zero alternatives
             for alt in content:
-                assert len(alt), repr(alt) # Can have empty alternatives
+                assert len(alt), repr(alt)  # Can have empty alternatives
         self.content = content
         self.min = min
         self.max = max
@@ -655,20 +682,29 @@ class WildcardPattern(BasePattern):
     def optimize(self):
         """Optimize certain stacked wildcard patterns."""
         subpattern = None
-        if (self.content is not None and
-            len(self.content) == 1 and len(self.content[0]) == 1):
+        if (
+            self.content is not None
+            and len(self.content) == 1
+            and len(self.content[0]) == 1
+        ):
             subpattern = self.content[0][0]
         if self.min == 1 and self.max == 1:
             if self.content is None:
                 return NodePattern(name=self.name)
-            if subpattern is not None and  self.name == subpattern.name:
+            if subpattern is not None and self.name == subpattern.name:
                 return subpattern.optimize()
-        if (self.min <= 1 and isinstance(subpattern, WildcardPattern) and
-            subpattern.min <= 1 and self.name == subpattern.name):
-            return WildcardPattern(subpattern.content,
-                                   self.min*subpattern.min,
-                                   self.max*subpattern.max,
-                                   subpattern.name)
+        if (
+            self.min <= 1
+            and isinstance(subpattern, WildcardPattern)
+            and subpattern.min <= 1
+            and self.name == subpattern.name
+        ):
+            return WildcardPattern(
+                subpattern.content,
+                self.min * subpattern.min,
+                self.max * subpattern.max,
+                subpattern.name,
+            )
         return self
 
     def match(self, node, results=None):
@@ -784,7 +820,7 @@ class WildcardPattern(BasePattern):
         if count < self.max:
             for alt in self.content:
                 for c0, r0 in generate_matches(alt, nodes):
-                    for c1, r1 in self._recursive_matches(nodes[c0:], count+1):
+                    for c1, r1 in self._recursive_matches(nodes[c0:], count + 1):
                         r = {}
                         r.update(r0)
                         r.update(r1)
@@ -792,7 +828,6 @@ class WildcardPattern(BasePattern):
 
 
 class NegatedPattern(BasePattern):
-
     def __init__(self, content=None):
         """
         Initializer.