from enum import Enum, Flag
 from functools import lru_cache, partial, wraps
 import io
+import itertools
 import keyword
 import logging
 from multiprocessing import Manager
 
     depth: int = 0
     leaves: List[Leaf] = Factory(list)
-    comments: List[Tuple[Index, Leaf]] = Factory(list)
+    # The LeafID keys of comments must remain ordered by the corresponding leaf's index
+    # in leaves
+    comments: Dict[LeafID, List[Leaf]] = Factory(dict)
     bracket_tracker: BracketTracker = Factory(BracketTracker)
     inside_brackets: bool = False
     should_explode: bool = False
         if comment.type != token.COMMENT:
             return False
 
-        after = len(self.leaves) - 1
-        if after == -1:
+        if not self.leaves:
             comment.type = STANDALONE_COMMENT
             comment.prefix = ""
             return False
 
         else:
-            self.comments.append((after, comment))
-            return True
-
-    def comments_after(self, leaf: Leaf, _index: int = -1) -> Iterator[Leaf]:
-        """Generate comments that should appear directly after `leaf`.
-
-        Provide a non-negative leaf `_index` to speed up the function.
-        """
-        if not self.comments:
-            return
-
-        if _index == -1:
-            for _index, _leaf in enumerate(self.leaves):
-                if leaf is _leaf:
-                    break
-
+            leaf_id = id(self.leaves[-1])
+            if leaf_id not in self.comments:
+                self.comments[leaf_id] = [comment]
             else:
-                return
+                self.comments[leaf_id].append(comment)
+            return True
 
-        for index, comment_after in self.comments:
-            if _index == index:
-                yield comment_after
+    def comments_after(self, leaf: Leaf) -> List[Leaf]:
+        """Return the comments that should appear directly after `leaf`.
+
+        The lookup is keyed on `id(leaf)`, so `leaf` must be the very object
+        the comments were registered for, not merely an equal one.  An empty
+        list is returned when no comments are attached to `leaf`.
+        """
+        return self.comments.get(id(leaf), [])
 
     def remove_trailing_comma(self) -> None:
         """Remove the trailing comma and moves the comments attached to it."""
-        comma_index = len(self.leaves) - 1
-        for i in range(len(self.comments)):
-            comment_index, comment = self.comments[i]
-            if comment_index == comma_index:
-                self.comments[i] = (comma_index - 1, comment)
+        # Remember, the LeafID keys of self.comments are ordered by the
+        # corresponding leaf's index in self.leaves
+        # If id(self.leaves[-2]) is in self.comments, the order doesn't change.
+        # Otherwise, we insert it into self.comments, and it becomes the last entry.
+        # However, since we delete id(self.leaves[-1]) from self.comments, the invariant
+        # is maintained
+        self.comments.setdefault(id(self.leaves[-2]), []).extend(
+            self.comments.get(id(self.leaves[-1]), [])
+        )
+        self.comments.pop(id(self.leaves[-1]), None)
         self.leaves.pop()
 
     def is_complex_subscript(self, leaf: Leaf) -> bool:
         res = f"{first.prefix}{indent}{first.value}"
         for leaf in leaves:
             res += str(leaf)
-        for _, comment in self.comments:
+        for comment in itertools.chain.from_iterable(self.comments.values()):
             res += str(comment)
         return res + "\n"
 
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             current_line.append(leaf)
 
-    for index, leaf in enumerate(line.leaves):
+    for leaf in line.leaves:
         yield from append_to_line(leaf)
 
-        for comment_after in line.comments_after(leaf, index):
+        for comment_after in line.comments_after(leaf):
             yield from append_to_line(comment_after)
 
         lowest_depth = min(lowest_depth, leaf.bracket_depth)
             current_line = Line(depth=line.depth, inside_brackets=line.inside_brackets)
             current_line.append(leaf)
 
-    for index, leaf in enumerate(line.leaves):
+    for leaf in line.leaves:
         yield from append_to_line(leaf)
 
-        for comment_after in line.comments_after(leaf, index):
+        for comment_after in line.comments_after(leaf):
             yield from append_to_line(comment_after)
 
     if current_line:
             return  # Multiline strings, we can't continue.
 
         comment: Optional[Leaf]
-        for comment in line.comments_after(leaf, index):
+        for comment in line.comments_after(leaf):
             length += len(comment.value)
 
         yield index, leaf, length