File size: 2,177 Bytes
5ac1897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import math
from tqdm import tqdm
from contextlib import contextmanager
from typing import Tuple, Union
from ..basic import BasicBatchWindow, bsb

class AdaptableBatchWindow(BasicBatchWindow):
    ''' A batch window `[sid, eid)` that a consumer can flag as "needs a smaller batch". '''

    def __init__(self, sid, eid, min_B):
        '''
        ### Args
        - sid: start index of the window, stored as `start_id`.
        - eid: end index of the window, stored as `end_id`.
        - min_B: lower bound of the batch size; windows no larger than this refuse to shrink.
        '''
        # NOTE(review): the base initializer is not invoked here; `self.size` used in
        # `shrink()` is presumably supplied by `BasicBatchWindow` -- confirm against `..basic`.
        self.start_id, self.end_id = sid, eid
        self.min_B = min_B
        self.shrinking = False  # raised by `shrink()` when a smaller window is requested

    def shrink(self):
        '''
        Request a smaller window.

        ### Returns
        - `False` if the window size is already at or below `min_B` (nothing is flagged).
        - `True` otherwise, after setting `self.shrinking`.
        '''
        if self.size <= self.min_B:
            return False
        self.shrinking = True
        return True


class asb(bsb):
    '''
    Sliding-batch iterator whose batch size adapts by halving and doubling.

    Each `__next__` call yields an `AdaptableBatchWindow`. A caller asks for a smaller batch
    by calling `shrink()` on the yielded window; if that sets the window's `shrinking` flag,
    the following `__next__` call halves the batch size (bounded below by `min_B`) and
    re-issues a window starting at the same start index. Otherwise the iterator advances to
    the previous window's end index and, once far enough past the last shrink position,
    doubles the batch size again (bounded above by `max_B`).
    '''

    def __init__(
        self,
        total       : int,
        bs_scope    : Union[Tuple[int, int], int],
        enable_tqdm : bool = False,
    ):
        '''
        Simple binary strategy.

        ### Args
        - total: number of items to cover; converted with `int()`.
        - bs_scope: either `(min_B, max_B)` or a single int used as `max_B` (then `min_B` is 1).
        - enable_tqdm: when True, a `tqdm` progress bar with `total` steps is created.
        '''
        # NOTE(review): `bsb.__init__` is not invoked; every attribute used by this class is set below.
        # Static hyperparameters.
        self.total = int(total)
        if isinstance(bs_scope, int):
            self.min_B = 1
            self.max_B = bs_scope
        else:
            self.min_B, self.max_B = bs_scope  # lower & upper bound of batch size
        # Dynamic state.
        self.B = self.max_B  # current batch size
        self.tqdm = tqdm(total=self.total) if enable_tqdm else None
        # Placeholder window ending at 0, so that the first real window starts at index 0.
        self.cur_window = AdaptableBatchWindow(sid=-1, eid=0, min_B=self.min_B)  # starting window
        self.last_shrink_id = None  # start index of the most recent shrink, `None` if never shrunk

    def __next__(self):
        '''
        Produce the next window.

        ### Returns
        - The new `AdaptableBatchWindow`, also stored in `self.cur_window`.

        ### Raises
        - StopIteration: when the next start index reaches `total`.
        '''
        if self.cur_window.shrinking:
            # Retry from the same start index with a smaller batch.
            sid = self.cur_window.sid
            if self.tqdm is not None:
                # The abandoned window was already counted when it was issued; undo that so
                # the retried range is not counted twice.
                self.tqdm.update(sid - self.cur_window.eid)
            self.shrink_B(sid)
        else:
            sid = self.cur_window.eid
            self.recover_B(sid)

        if sid >= self.total:
            # Compare with `None` explicitly: a tqdm bar is falsy when its total is 0,
            # which would otherwise skip closing it.
            if self.tqdm is not None:
                self.tqdm.close()
            raise StopIteration

        eid = min(sid + self.B, self.total)
        self.cur_window = AdaptableBatchWindow(sid, eid, min_B=self.min_B)
        if self.tqdm is not None:
            self.tqdm.update(eid - sid)
        return self.cur_window

    def shrink_B(self, cur_id:int):
        '''
        Halve the batch size (rounded up, not below `min_B`) and remember where it happened.

        ### Args
        - cur_id: start index at which the shrink was requested.
        '''
        self.last_shrink_id = cur_id
        self.cur_window.shrinking = False  # the request has been consumed
        self.B = max(math.ceil(self.B/2), self.min_B)

    def recover_B(self, cur_id:int):
        '''
        Double the batch size (not above `max_B`) once the iterator is far enough past the last shrink.

        ### Args
        - cur_id: start index of the window about to be issued.
        '''
        # `last_shrink_id` may legitimately be 0 (shrink on the very first window), so test
        # against `None` rather than relying on truthiness.
        if self.last_shrink_id is None or self.B >= self.max_B:
            return
        newer_B = min(self.B * 2, self.max_B)
        # Only grow when more than `newer_B` items separate us from the last shrink position.
        if self.last_shrink_id < cur_id - newer_B:
            self.B = newer_B