File size: 1,458 Bytes
fc3399e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from pydantic import ConfigDict

from bytelatent.data.iterators.abstract_iterator import (
    PydanticIteratorState,
    StatefulIterator,
)
from bytelatent.data.iterators.arrow_iterator import ArrowFileIteratorState
from bytelatent.data.iterators.dev_iterators import BltTestIteratorState


class LimitIteratorState(PydanticIteratorState):
    """Serializable checkpoint for a ``LimitIterator``.

    Stores the wrapped iterator's own state alongside the limit bookkeeping so
    that ``build()`` can recreate an equivalent ``LimitIterator``.
    """

    # Unknown keys are rejected when validating a saved state.
    model_config = ConfigDict(extra="forbid")
    # State of the wrapped iterator; rebuilt via its own ``build()``.
    base_iterator_state: (
        BltTestIteratorState | ArrowFileIteratorState | PydanticIteratorState
    )
    # Number of items already emitted before this checkpoint was taken.
    n_yielded: int
    # Maximum number of items to emit (negative means no cap).
    limit: int

    def build(self) -> "LimitIterator":
        """Recreate the ``LimitIterator`` described by this state."""
        restored_base = self.base_iterator_state.build()
        return LimitIterator(
            base_iterator=restored_base,
            limit=self.limit,
            n_yielded=self.n_yielded,
        )


class LimitIterator(StatefulIterator):
    """Stateful iterator that passes through at most ``limit`` items of a base iterator.

    A negative ``limit`` disables the cap, so every item of the base iterator
    is passed through. ``n_yielded`` counts items already emitted; passing it to
    the constructor (as ``LimitIteratorState.build`` does) lets a restored
    iterator keep counting toward the same limit.
    """

    def __init__(self, base_iterator: StatefulIterator, limit: int, n_yielded: int = 0):
        self.base_iterator = base_iterator
        self.n_yielded = n_yielded
        self.limit = limit

    def get_state(self) -> "LimitIteratorState":
        """Snapshot the base iterator's state together with the yield count and limit."""
        return LimitIteratorState(
            base_iterator_state=self.base_iterator.get_state(),
            n_yielded=self.n_yielded,
            limit=self.limit,
        )

    def create_iter(self):
        """Yield items from the base iterator until the limit or its end is reached.

        The limit is checked before the base iterator is advanced, so no item is
        pulled from the base iterator and then dropped.
        """
        iterator = self.base_iterator.create_iter()
        while self.limit < 0 or self.n_yielded < self.limit:
            try:
                item = next(iterator)
            except StopIteration:
                # Base iterator ran out before the limit was reached.
                return
            # Count the item *before* yielding: while the consumer holds it, the
            # generator is suspended at ``yield``, and a get_state() taken then
            # must already account for this item, matching the base iterator's
            # state, which has advanced past it. Otherwise a restored iterator
            # would emit one item beyond the limit.
            self.n_yielded += 1
            yield item