Spaces:
Sleeping
Sleeping
from tqdm import tqdm | |
class BasicBatchWindow():
    """A window of indices ``[start_id, end_id)`` describing one batch.

    The bounds are stored in the ``start_id`` / ``end_id`` attributes and are
    also returned by the ``sid()`` / ``eid()`` accessor methods. ``size()``
    treats the window as half-open, i.e. ``end_id`` is excluded.
    """

    def __init__(self, sid, eid):
        """Create a window.

        Args:
            sid: First index of the window (stored as ``start_id``).
            eid: End index of the window (stored as ``end_id``).
        """
        self.start_id = sid
        self.end_id = eid

    def sid(self):
        """Return the start index of the window."""
        return self.start_id

    def eid(self):
        """Return the end index of the window."""
        return self.end_id

    def size(self):
        """Return the number of indices in the window (``end_id - start_id``)."""
        # Use the attributes, not ``self.eid - self.sid``: those are bound
        # methods, and subtracting them raises TypeError.
        return self.end_id - self.start_id

    def __repr__(self):
        return f"{type(self).__name__}({self.start_id!r}, {self.end_id!r})"
class bsb():
    """Iterator that splits the index range ``[0, total)`` into batch windows.

    Each ``next()`` returns a ``BasicBatchWindow`` covering the next
    ``batch_size`` indices; the last window is truncated at ``total``.
    For example, ``total=10, batch_size=4`` yields the windows
    ``[0, 4)``, ``[4, 8)``, ``[8, 10)``. Progress can optionally be reported
    through a ``tqdm`` bar sized to ``total``.
    """

    def __init__(
        self,
        total : int,
        batch_size : int,
        enable_tqdm : bool = False,
    ):
        """
        Args:
            total: Number of indices to cover; converted with ``int()``.
            batch_size: Maximum number of indices per window; must be >= 1.
            enable_tqdm: If true, create a ``tqdm`` progress bar with
                ``total=total``.

        Raises:
            ValueError: If ``batch_size`` is smaller than 1 (the iterator
                would otherwise never advance).
        """
        # Static hyperparameters.
        self.total = int(total)
        self.B = batch_size
        # Validate before creating the progress bar so no bar is leaked.
        if self.B < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size!r}")
        # Dynamic state.
        self.tqdm = tqdm(total=self.total) if enable_tqdm else None
        # Empty window ending at 0: the first batch starts at index 0, and
        # crediting this window to the progress bar adds nothing.
        self.cur_window = BasicBatchWindow(0, 0)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next window, or raise StopIteration once ``total`` is covered."""
        prev = self.cur_window
        # Read the window bounds as attributes; ``sid``/``eid`` on the window
        # are methods, so comparing/subtracting them would raise TypeError.
        prev_sid, prev_eid = prev.start_id, prev.end_id
        if self.tqdm is not None:
            # Credit the window returned by the previous call. Doing this
            # before the stop check also counts the final window.
            self.tqdm.update(prev_eid - prev_sid)
        if prev_eid >= self.total:
            if self.tqdm is not None:
                self.tqdm.close()
                # Drop the closed bar so repeated next() calls after
                # exhaustion neither re-credit nor re-close it.
                self.tqdm = None
            raise StopIteration
        sid = prev_eid
        eid = min(sid + self.B, self.total)
        self.cur_window = BasicBatchWindow(sid, eid)
        return self.cur_window