Spaces:
Starting
on
L4
Starting
on
L4
import math | |
from tqdm import tqdm | |
from contextlib import contextmanager | |
from typing import Tuple, Union | |
from ..basic import BasicBatchWindow, bsb | |
class AdaptableBatchWindow(BasicBatchWindow): | |
def __init__(self, sid, eid, min_B): | |
self.start_id = sid | |
self.end_id = eid | |
self.min_B = min_B | |
self.shrinking = False | |
def shrink(self): | |
if self.size <= self.min_B: | |
return False | |
else: | |
self.shrinking = True | |
return True | |
class asb(bsb): | |
def __init__( | |
self, | |
total : int, | |
bs_scope : Union[Tuple[int, int], int], | |
enable_tqdm : bool = False, | |
): | |
''' Simple binary strategy. ''' | |
# Static hyperparameters. | |
self.total = int(total) | |
if isinstance(bs_scope, int): | |
self.min_B = 1 | |
self.max_B = bs_scope | |
else: | |
self.min_B, self.max_B = bs_scope # lower & upper bound of batch size | |
# Dynamic state. | |
self.B = self.max_B # current batch size | |
self.tqdm = tqdm(total=self.total) if enable_tqdm else None | |
self.cur_window = AdaptableBatchWindow(sid=-1, eid=0, min_B=self.min_B) # starting window | |
self.last_shrink_id = None | |
def __next__(self): | |
if self.cur_window.shrinking: | |
sid = self.cur_window.sid | |
self.shrink_B(sid) | |
else: | |
sid = self.cur_window.eid | |
self.recover_B(sid) | |
if sid >= self.total: | |
if self.tqdm: self.tqdm.close() | |
raise StopIteration | |
eid = min(sid + self.B, self.total) | |
self.cur_window = AdaptableBatchWindow(sid, eid, min_B=self.min_B) | |
if self.tqdm: self.tqdm.update(eid - sid) | |
return self.cur_window | |
def shrink_B(self, cur_id:int): | |
self.last_shrink_id = cur_id | |
self.cur_window.shrinking = False | |
self.B = max(math.ceil(self.B/2), self.min_B) | |
def recover_B(self, cur_id:int): | |
if self.last_shrink_id and self.B < self.max_B: | |
newer_B = min(self.B * 2, self.max_B) | |
if self.last_shrink_id < cur_id - newer_B: | |
self.B = newer_B |