IsshikiHugh's picture
feat: CPU demo
5ac1897
import math
from tqdm import tqdm
from contextlib import contextmanager
from typing import Tuple, Union
from ..basic import BasicBatchWindow, bsb
class AdaptableBatchWindow(BasicBatchWindow):
def __init__(self, sid, eid, min_B):
self.start_id = sid
self.end_id = eid
self.min_B = min_B
self.shrinking = False
def shrink(self):
if self.size <= self.min_B:
return False
else:
self.shrinking = True
return True
class asb(bsb):
def __init__(
self,
total : int,
bs_scope : Union[Tuple[int, int], int],
enable_tqdm : bool = False,
):
''' Simple binary strategy. '''
# Static hyperparameters.
self.total = int(total)
if isinstance(bs_scope, int):
self.min_B = 1
self.max_B = bs_scope
else:
self.min_B, self.max_B = bs_scope # lower & upper bound of batch size
# Dynamic state.
self.B = self.max_B # current batch size
self.tqdm = tqdm(total=self.total) if enable_tqdm else None
self.cur_window = AdaptableBatchWindow(sid=-1, eid=0, min_B=self.min_B) # starting window
self.last_shrink_id = None
def __next__(self):
if self.cur_window.shrinking:
sid = self.cur_window.sid
self.shrink_B(sid)
else:
sid = self.cur_window.eid
self.recover_B(sid)
if sid >= self.total:
if self.tqdm: self.tqdm.close()
raise StopIteration
eid = min(sid + self.B, self.total)
self.cur_window = AdaptableBatchWindow(sid, eid, min_B=self.min_B)
if self.tqdm: self.tqdm.update(eid - sid)
return self.cur_window
def shrink_B(self, cur_id:int):
self.last_shrink_id = cur_id
self.cur_window.shrinking = False
self.B = max(math.ceil(self.B/2), self.min_B)
def recover_B(self, cur_id:int):
if self.last_shrink_id and self.B < self.max_B:
newer_B = min(self.B * 2, self.max_B)
if self.last_shrink_id < cur_id - newer_B:
self.B = newer_B