git.madduck.net Git - etc/vim.git/blobdiff - blib2to3/pytree.py

madduck's git repository

Every one of the projects in this repository is available at the canonical URL git://git.madduck.net/madduck/pub/<projectpath> — see each project's metadata for the exact URL.

All patches and comments are welcome. Please squash your changes into logical commits before using git-format-patch and git-send-email to send them to patches@git.madduck.net. If you read over the Git project's submission guidelines and adhere to them, I'd be especially grateful.

SSH access, as well as push access, can be arranged individually.

If you use my repositories frequently, consider adding the following snippet to ~/.gitconfig and using the third clone URL listed for each project:

[url "git://git.madduck.net/madduck/"]
  insteadOf = madduck:

Simplify some code flow
[etc/vim.git] / blib2to3 / pytree.py
index 693366f7b2e4fb3bba6ad237aefa8bfca7904dc3..6776491cfbf26aa69bdb46977cfeefdff91f4221 100644 (file)
@@ -18,16 +18,21 @@ from io import StringIO
 HUGE = 0x7FFFFFFF  # maximum repeat count, default max
 
 _type_reprs = {}
 HUGE = 0x7FFFFFFF  # maximum repeat count, default max
 
 _type_reprs = {}
+
+
 def type_repr(type_num):
     global _type_reprs
     if not _type_reprs:
         from .pygram import python_symbols
 def type_repr(type_num):
     global _type_reprs
     if not _type_reprs:
         from .pygram import python_symbols
+
         # printing tokens is possible but not as useful
         # from .pgen2 import token // token.__dict__.items():
         for name, val in python_symbols.__dict__.items():
         # printing tokens is possible but not as useful
         # from .pgen2 import token // token.__dict__.items():
         for name, val in python_symbols.__dict__.items():
-            if type(val) == int: _type_reprs[val] = name
+            if type(val) == int:
+                _type_reprs[val] = name
     return _type_reprs.setdefault(type_num, type_num)
 
     return _type_reprs.setdefault(type_num, type_num)
 
+
 class Base(object):
 
     """
 class Base(object):
 
     """
@@ -40,7 +45,7 @@ class Base(object):
     """
 
     # Default values for instance variables
     """
 
     # Default values for instance variables
-    type = None    # int: token number (< 256) or symbol number (>= 256)
+    type = None  # int: token number (< 256) or symbol number (>= 256)
     parent = None  # Parent node pointer, or None
     children = ()  # Tuple of subnodes
     was_changed = False
     parent = None  # Parent node pointer, or None
     children = ()  # Tuple of subnodes
     was_changed = False
@@ -61,7 +66,7 @@ class Base(object):
             return NotImplemented
         return self._eq(other)
 
             return NotImplemented
         return self._eq(other)
 
-    __hash__ = None # For Py3 compatibility.
+    __hash__ = None  # For Py3 compatibility.
 
     def _eq(self, other):
         """
 
     def _eq(self, other):
         """
@@ -115,8 +120,9 @@ class Base(object):
             else:
                 l_children.append(ch)
         assert found, (self.children, self, new)
             else:
                 l_children.append(ch)
         assert found, (self.children, self, new)
-        self.parent.changed()
         self.parent.children = l_children
         self.parent.children = l_children
+        self.parent.changed()
+        self.parent.invalidate_sibling_maps()
         for x in new:
             x.parent = self.parent
         self.parent = None
         for x in new:
             x.parent = self.parent
         self.parent = None
@@ -131,6 +137,8 @@ class Base(object):
         return node.lineno
 
     def changed(self):
         return node.lineno
 
     def changed(self):
+        if self.was_changed:
+            return
         if self.parent:
             self.parent.changed()
         self.was_changed = True
         if self.parent:
             self.parent.changed()
         self.was_changed = True
@@ -143,8 +151,9 @@ class Base(object):
         if self.parent:
             for i, node in enumerate(self.parent.children):
                 if node is self:
         if self.parent:
             for i, node in enumerate(self.parent.children):
                 if node is self:
-                    self.parent.changed()
                     del self.parent.children[i]
                     del self.parent.children[i]
+                    self.parent.changed()
+                    self.parent.invalidate_sibling_maps()
                     self.parent = None
                     return i
 
                     self.parent = None
                     return i
 
@@ -157,13 +166,9 @@ class Base(object):
         if self.parent is None:
             return None
 
         if self.parent is None:
             return None
 
-        # Can't use index(); we need to test by identity
-        for i, child in enumerate(self.parent.children):
-            if child is self:
-                try:
-                    return self.parent.children[i+1]
-                except IndexError:
-                    return None
+        if self.parent.next_sibling_map is None:
+            self.parent.update_sibling_maps()
+        return self.parent.next_sibling_map[id(self)]
 
     @property
     def prev_sibling(self):
 
     @property
     def prev_sibling(self):
@@ -174,12 +179,9 @@ class Base(object):
         if self.parent is None:
             return None
 
         if self.parent is None:
             return None
 
-        # Can't use index(); we need to test by identity
-        for i, child in enumerate(self.parent.children):
-            if child is self:
-                if i == 0:
-                    return None
-                return self.parent.children[i-1]
+        if self.parent.prev_sibling_map is None:
+            self.parent.update_sibling_maps()
+        return self.parent.prev_sibling_map[id(self)]
 
     def leaves(self):
         for child in self.children:
 
     def leaves(self):
         for child in self.children:
@@ -201,17 +203,16 @@ class Base(object):
         return next_sib.prefix
 
     if sys.version_info < (3, 0):
         return next_sib.prefix
 
     if sys.version_info < (3, 0):
+
         def __str__(self):
             return str(self).encode("ascii")
 
         def __str__(self):
             return str(self).encode("ascii")
 
+
 class Node(Base):
 
     """Concrete implementation for interior nodes."""
 
 class Node(Base):
 
     """Concrete implementation for interior nodes."""
 
-    def __init__(self,type, children,
-                 context=None,
-                 prefix=None,
-                 fixers_applied=None):
+    def __init__(self, type, children, context=None, prefix=None, fixers_applied=None):
         """
         Initializer.
 
         """
         Initializer.
 
@@ -226,6 +227,7 @@ class Node(Base):
         for ch in self.children:
             assert ch.parent is None, repr(ch)
             ch.parent = self
         for ch in self.children:
             assert ch.parent is None, repr(ch)
             ch.parent = self
+        self.invalidate_sibling_maps()
         if prefix is not None:
             self.prefix = prefix
         if fixers_applied:
         if prefix is not None:
             self.prefix = prefix
         if fixers_applied:
@@ -235,9 +237,11 @@ class Node(Base):
 
     def __repr__(self):
         """Return a canonical string representation."""
 
     def __repr__(self):
         """Return a canonical string representation."""
-        return "%s(%s, %r)" % (self.__class__.__name__,
-                               type_repr(self.type),
-                               self.children)
+        return "%s(%s, %r)" % (
+            self.__class__.__name__,
+            type_repr(self.type),
+            self.children,
+        )
 
     def __unicode__(self):
         """
 
     def __unicode__(self):
         """
@@ -256,8 +260,11 @@ class Node(Base):
 
     def clone(self):
         """Return a cloned (deep) copy of self."""
 
     def clone(self):
         """Return a cloned (deep) copy of self."""
-        return Node(self.type, [ch.clone() for ch in self.children],
-                    fixers_applied=self.fixers_applied)
+        return Node(
+            self.type,
+            [ch.clone() for ch in self.children],
+            fixers_applied=self.fixers_applied,
+        )
 
     def post_order(self):
         """Return a post-order iterator for the tree."""
 
     def post_order(self):
         """Return a post-order iterator for the tree."""
@@ -294,6 +301,7 @@ class Node(Base):
         self.children[i].parent = None
         self.children[i] = child
         self.changed()
         self.children[i].parent = None
         self.children[i] = child
         self.changed()
+        self.invalidate_sibling_maps()
 
     def insert_child(self, i, child):
         """
 
     def insert_child(self, i, child):
         """
@@ -303,6 +311,7 @@ class Node(Base):
         child.parent = self
         self.children.insert(i, child)
         self.changed()
         child.parent = self
         self.children.insert(i, child)
         self.changed()
+        self.invalidate_sibling_maps()
 
     def append_child(self, child):
         """
 
     def append_child(self, child):
         """
@@ -312,6 +321,21 @@ class Node(Base):
         child.parent = self
         self.children.append(child)
         self.changed()
         child.parent = self
         self.children.append(child)
         self.changed()
+        self.invalidate_sibling_maps()
+
+    def invalidate_sibling_maps(self):
+        self.prev_sibling_map = None
+        self.next_sibling_map = None
+
+    def update_sibling_maps(self):
+        self.prev_sibling_map = _prev = {}
+        self.next_sibling_map = _next = {}
+        previous = None
+        for current in self.children:
+            _prev[id(current)] = previous
+            _next[id(previous)] = current
+            previous = current
+        _next[id(current)] = None
 
 
 class Leaf(Base):
 
 
 class Leaf(Base):
@@ -320,13 +344,10 @@ class Leaf(Base):
 
     # Default values for instance variables
     _prefix = ""  # Whitespace and comments preceding this token in the input
 
     # Default values for instance variables
     _prefix = ""  # Whitespace and comments preceding this token in the input
-    lineno = 0    # Line where this token starts in the input
-    column = 0    # Column where this token tarts in the input
+    lineno = 0  # Line where this token starts in the input
+    column = 0  # Column where this token starts in the input
 
 
-    def __init__(self, type, value,
-                 context=None,
-                 prefix=None,
-                 fixers_applied=[]):
+    def __init__(self, type, value, context=None, prefix=None, fixers_applied=[]):
         """
         Initializer.
 
         """
         Initializer.
 
@@ -345,9 +366,12 @@ class Leaf(Base):
     def __repr__(self):
         """Return a canonical string representation."""
         from .pgen2.token import tok_name
     def __repr__(self):
         """Return a canonical string representation."""
         from .pgen2.token import tok_name
-        return "%s(%s, %r)" % (self.__class__.__name__,
-                               tok_name.get(self.type, self.type),
-                               self.value)
+
+        return "%s(%s, %r)" % (
+            self.__class__.__name__,
+            tok_name.get(self.type, self.type),
+            self.value,
+        )
 
     def __unicode__(self):
         """
 
     def __unicode__(self):
         """
@@ -366,9 +390,12 @@ class Leaf(Base):
 
     def clone(self):
         """Return a cloned (deep) copy of self."""
 
     def clone(self):
         """Return a cloned (deep) copy of self."""
-        return Leaf(self.type, self.value,
-                    (self.prefix, (self.lineno, self.column)),
-                    fixers_applied=self.fixers_applied)
+        return Leaf(
+            self.type,
+            self.value,
+            (self.prefix, (self.lineno, self.column)),
+            fixers_applied=self.fixers_applied,
+        )
 
     def leaves(self):
         yield self
 
     def leaves(self):
         yield self
@@ -393,6 +420,7 @@ class Leaf(Base):
         self.changed()
         self._prefix = prefix
 
         self.changed()
         self._prefix = prefix
 
+
 def convert(gr, raw_node):
     """
     Convert raw node information to a Node or Leaf instance.
 def convert(gr, raw_node):
     """
     Convert raw node information to a Node or Leaf instance.
@@ -429,9 +457,9 @@ class BasePattern(object):
     """
 
     # Defaults for instance variables
     """
 
     # Defaults for instance variables
-    type = None     # Node type (token if < 256, symbol if >= 256)
+    type = None  # Node type (token if < 256, symbol if >= 256)
     content = None  # Optional content matching pattern
     content = None  # Optional content matching pattern
-    name = None     # Optional name used to store match in results dict
+    name = None  # Optional name used to store match in results dict
 
     def __new__(cls, *args, **kwds):
         """Constructor that prevents BasePattern from being instantiated."""
 
     def __new__(cls, *args, **kwds):
         """Constructor that prevents BasePattern from being instantiated."""
@@ -499,7 +527,6 @@ class BasePattern(object):
 
 
 class LeafPattern(BasePattern):
 
 
 class LeafPattern(BasePattern):
-
     def __init__(self, type=None, content=None, name=None):
         """
         Initializer.  Takes optional type, content, and name.
     def __init__(self, type=None, content=None, name=None):
         """
         Initializer.  Takes optional type, content, and name.
@@ -646,7 +673,7 @@ class WildcardPattern(BasePattern):
             # Check sanity of alternatives
             assert len(content), repr(content)  # Can't have zero alternatives
             for alt in content:
             # Check sanity of alternatives
             assert len(content), repr(content)  # Can't have zero alternatives
             for alt in content:
-                assert len(alt), repr(alt) # Can have empty alternatives
+                assert len(alt), repr(alt)  # Can have empty alternatives
         self.content = content
         self.min = min
         self.max = max
         self.content = content
         self.min = min
         self.max = max
@@ -655,20 +682,29 @@ class WildcardPattern(BasePattern):
     def optimize(self):
         """Optimize certain stacked wildcard patterns."""
         subpattern = None
     def optimize(self):
         """Optimize certain stacked wildcard patterns."""
         subpattern = None
-        if (self.content is not None and
-            len(self.content) == 1 and len(self.content[0]) == 1):
+        if (
+            self.content is not None
+            and len(self.content) == 1
+            and len(self.content[0]) == 1
+        ):
             subpattern = self.content[0][0]
         if self.min == 1 and self.max == 1:
             if self.content is None:
                 return NodePattern(name=self.name)
             subpattern = self.content[0][0]
         if self.min == 1 and self.max == 1:
             if self.content is None:
                 return NodePattern(name=self.name)
-            if subpattern is not None and  self.name == subpattern.name:
+            if subpattern is not None and self.name == subpattern.name:
                 return subpattern.optimize()
                 return subpattern.optimize()
-        if (self.min <= 1 and isinstance(subpattern, WildcardPattern) and
-            subpattern.min <= 1 and self.name == subpattern.name):
-            return WildcardPattern(subpattern.content,
-                                   self.min*subpattern.min,
-                                   self.max*subpattern.max,
-                                   subpattern.name)
+        if (
+            self.min <= 1
+            and isinstance(subpattern, WildcardPattern)
+            and subpattern.min <= 1
+            and self.name == subpattern.name
+        ):
+            return WildcardPattern(
+                subpattern.content,
+                self.min * subpattern.min,
+                self.max * subpattern.max,
+                subpattern.name,
+            )
         return self
 
     def match(self, node, results=None):
         return self
 
     def match(self, node, results=None):
@@ -784,7 +820,7 @@ class WildcardPattern(BasePattern):
         if count < self.max:
             for alt in self.content:
                 for c0, r0 in generate_matches(alt, nodes):
         if count < self.max:
             for alt in self.content:
                 for c0, r0 in generate_matches(alt, nodes):
-                    for c1, r1 in self._recursive_matches(nodes[c0:], count+1):
+                    for c1, r1 in self._recursive_matches(nodes[c0:], count + 1):
                         r = {}
                         r.update(r0)
                         r.update(r1)
                         r = {}
                         r.update(r0)
                         r.update(r1)
@@ -792,7 +828,6 @@ class WildcardPattern(BasePattern):
 
 
 class NegatedPattern(BasePattern):
 
 
 class NegatedPattern(BasePattern):
-
     def __init__(self, content=None):
         """
         Initializer.
     def __init__(self, content=None):
         """
         Initializer.