File size: 1,167 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from tqdm import tqdm

class BasicBatchWindow:
    """A batch window spanning indices ``sid`` (inclusive) to ``eid`` (exclusive).

    Args:
        sid: First index of the window; kept as ``start_id``.
        eid: End index of the window; kept as ``end_id``.
    """

    def __init__(self, sid, eid):
        # Raw bounds live in plain attributes; the properties below expose
        # them under the short, read-only names.
        self.start_id, self.end_id = sid, eid

    @property
    def sid(self):
        """Start index of the window (``start_id``)."""
        return self.start_id

    @property
    def eid(self):
        """End index of the window (``end_id``)."""
        return self.end_id

    @property
    def size(self):
        """Number of indices covered, computed as ``eid - sid``."""
        span = self.eid - self.sid
        return span


class bsb:
    """Iterate over ``[0, total)`` in consecutive batch windows.

    Each ``next()`` returns a :class:`BasicBatchWindow` ``[sid, eid)`` holding
    at most ``batch_size`` indices. The last window is truncated at ``total``.
    When ``enable_tqdm`` is set, a ``tqdm`` bar is advanced by a window's size
    when the *following* window is requested, i.e. after the caller has
    finished with it. The bar is closed when the iterator is exhausted.

    Args:
        total: Number of indices to cover; converted with ``int()``.
        batch_size: Maximum number of indices per window; must be positive.
        enable_tqdm: Whether to create and drive a ``tqdm`` progress bar.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """

    def __init__(
        self,
        total: int,
        batch_size: int,
        enable_tqdm: bool = False,
    ):
        # Static hyperparameters.
        self.total = int(total)
        if batch_size <= 0:
            # A non-positive step never reaches ``total`` and would loop
            # forever. Reject it before any progress bar is created.
            raise ValueError(f"batch_size must be positive, got {batch_size!r}")
        self.B = batch_size
        # Dynamic state.
        self.tqdm = tqdm(total=self.total) if enable_tqdm else None
        # Empty starting window ending at 0. The first real window starts at
        # index 0, and crediting this one adds nothing to the progress bar.
        self.cur_window = BasicBatchWindow(0, 0)

    def __iter__(self):
        """Return the iterator itself; ``bsb`` is a one-shot iterator."""
        return self

    def __next__(self):
        """Return the next window, or raise ``StopIteration`` once ``total`` is reached."""
        finished = self.cur_window
        # The window handed out by the previous call is now complete; credit it.
        # Compare against None: tqdm's __bool__ is False when total <= 0.
        if self.tqdm is not None and finished.size > 0:
            self.tqdm.update(finished.size)

        if finished.eid >= self.total:
            # Park on an empty window at the end. Further calls after
            # exhaustion then credit nothing and keep raising StopIteration.
            self.cur_window = BasicBatchWindow(self.total, self.total)
            if self.tqdm is not None:
                self.tqdm.close()
            raise StopIteration

        sid = finished.eid
        eid = min(sid + self.B, self.total)
        self.cur_window = BasicBatchWindow(sid, eid)
        return self.cur_window