@dataclass
class ReleaseRange:
    """Tokens looked ahead during one ``TokenProxy.release()`` block.

    ``start`` is the absolute position (number of tokens already returned by
    ``TokenProxy.__next__``) at which the lookahead began, and ``tokens[i]``
    is the token at absolute position ``start + i``. ``end`` stays ``None``
    while the range is still open and is set by :meth:`lock`.
    """

    start: int
    end: Optional[int] = None
    tokens: List[Any] = field(default_factory=list)

    def lock(self) -> None:
        """Close the range: ``end`` becomes one past the last eaten position."""
        self.end = self.start + len(self.tokens)


class TokenProxy:
    """Iterator over *generator* that supports cached lookahead.

    Inside ``with proxy.release():`` the methods :meth:`eat` and
    :meth:`can_advance` peek at upcoming tokens without consuming them. Every
    token peeked at is cached in a :class:`ReleaseRange` and later replayed,
    in order, by :meth:`__next__`.
    """

    def __init__(self, generator: Any) -> None:
        self._tokens = generator
        # Absolute position of the next token __next__ will return.
        self._counter = 0
        # All lookahead ranges opened so far, oldest first; the last one is
        # the range that eat() fills.
        self._release_ranges: List[ReleaseRange] = []

    @contextmanager
    def release(self) -> Iterator["TokenProxy"]:
        """Open a lookahead range starting at the current position.

        Yields this proxy. Inside the block, ``eat(i)`` refers to the token
        ``i`` positions ahead of the one ``__next__`` would return next.
        """
        release_range = ReleaseRange(self._counter)
        self._release_ranges.append(release_range)
        try:
            yield self
        finally:
            # Lock the range to the final position that has been eaten.
            release_range.lock()

    def _fetch(self, position: int) -> Any:
        """Return the token at absolute *position*.

        A position already eaten by any range is served from that range's
        cache. This matters when a new range starts inside an older one, which
        already consumed those tokens from the generator. Positions are always
        requested in increasing order starting from ``self._counter``, so a
        position no range has cached is exactly the generator's next token.
        Raises ``StopIteration`` when the generator is exhausted.
        """
        for release_range in self._release_ranges:
            offset = position - release_range.start
            if 0 <= offset < len(release_range.tokens):
                return release_range.tokens[offset]
        return next(self._tokens)

    def eat(self, point: int) -> Any:
        """Return the token *point* positions after the current range's start.

        Uses the most recently opened release range, caching every token up
        to *point* in it. With no range opened at all this raises
        ``IndexError``. Raises ``StopIteration`` if the generator runs out
        before *point* is reached; tokens fetched before that stay cached.
        """
        current = self._release_ranges[-1]
        eaten_tokens = current.tokens
        while point >= len(eaten_tokens):
            eaten_tokens.append(self._fetch(current.start + len(eaten_tokens)))
        return eaten_tokens[point]

    def __iter__(self) -> "TokenProxy":
        return self

    def __next__(self) -> Any:
        """Return the next token, replaying it from a lookahead cache if any."""
        # If the current position was already looked up, _fetch returns the
        # eaten token; otherwise it advances the underlying generator.
        token = self._fetch(self._counter)
        self._counter += 1
        return token

    def can_advance(self, to: int) -> bool:
        """Return whether ``eat(to)`` succeeds, i.e. the token exists.

        The eat operation caches what it fetches, so checking here costs
        nothing extra when the token is eaten afterwards.
        """
        try:
            self.eat(to)
        except StopIteration:
            return False
        else:
            return True
+
+