import torch

from .prior import Batch
from ..utils import default_device


# Cache of loaded models, keyed by their 'group:index' name.
loaded_models = {}

def get_model(model_name, device):
    """Load a trained model from a submitit checkpoint and cache it.

    `model_name` is a 'group:index' string; the third element of the stored
    result tuple is the model.
    """
    if model_name not in loaded_models:
        import submitit
        group, index = model_name.split(':')
        ex = submitit.get_executor()
        model = ex.get_group(group)[int(index)].results()[0][2]
        model.to(device)
        loaded_models[model_name] = model
    return loaded_models[model_name]
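
# A minimal usage sketch (the group name and index are illustrative, not from
# the original source):
#
#   level_0_model = get_model('my_training_group:0', 'cuda:0')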

@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, get_batch, model, single_eval_pos, epoch, device=default_device, hyperparameters=None, **kwargs):
    """
    Wrap `get_batch` to build batches for multi-level (fantasized) training.

    Important assumptions:
        `hyperparameters` contains 'inf_batch_size', 'max_level',
        'sample_only_one_level', 'eval_seq_len' and 'epochs_per_level'.

    You can train a new model, based on an old one, that samples from a single
    level only: specify `level_0_model` as a 'group:index' string and that
    model will be loaded from its checkpoint.

    :param batch_size: number of datasets in the batch
    :param seq_len: number of points per dataset
    :param num_features: number of input features
    :param get_batch: the wrapped batch-sampling function
    :param model: the model being trained; for levels > 0 it is also used to fantasize targets
    :param single_eval_pos: position that splits the train part from the evaluation part
    :param epoch: current epoch (0-indexed), used to determine the level
    :param device: device to sample on
    :param hyperparameters: dict of hyperparameters, see the assumptions above
    :param kwargs: passed through to the wrapped `get_batch`
    :return: a `Batch` whose `style` holds the sampled level of each dataset
    """
    if level_0_model := hyperparameters.get('level_0_model', None):
        assert hyperparameters['sample_only_one_level'], "level_0_model only makes sense if you sample only one level"
        assert hyperparameters['max_level'] == 1, "level_0_model only makes sense if you sample only one level"
        level_0_model = get_model(level_0_model, device)
        model = level_0_model

    # The level is the number of fantasized steps. It starts at 0 for the first
    # epochs and grows by one every `epochs_per_level` epochs, capped by
    # `max_level` and by the number of evaluation positions.
    epochs_per_level = hyperparameters['epochs_per_level']
    share_predict_mean_distribution = hyperparameters.get('share_predict_mean_distribution', 0.)
    use_mean_prediction = share_predict_mean_distribution or (
        model.decoder_dict_once is not None and 'mean_prediction' in model.decoder_dict_once
    )
    num_evals = seq_len - single_eval_pos
    level = min(epoch // epochs_per_level, hyperparameters['max_level'], num_evals - 1)
    if level_0_model:
        level = 1
    eval_seq_len = hyperparameters['eval_seq_len']
    # Without mean prediction, `eval_seq_len` extra positions are appended on
    # which the fantasized evaluations are scored.
    add_seq_len = 0 if use_mean_prediction else eval_seq_len
    long_seq_len = seq_len + add_seq_len

    if level_0_model:
        styles = torch.ones(batch_size, 1, device=device, dtype=torch.long)
    elif hyperparameters['sample_only_one_level']:
        # one sampled level, shared by the whole batch
        styles = torch.randint(level + 1, (1, 1), device=device).repeat(batch_size, 1)
    else:
        # one sampled level per dataset; sorted so that each level's datasets are contiguous
        styles = torch.randint(level + 1, (batch_size, 1), device=device).sort(0).values

    predict_mean_distribution = None
    if share_predict_mean_distribution:
        max_used_level = styles.max()
        # In the last `share_predict_mean_distribution` fraction of a level's
        # training (and only below `max_level`), flag the batch for
        # mean-distribution prediction. This assumes epochs are 0-indexed.
        share_of_training = epoch / epochs_per_level
        predict_mean_distribution = (
            share_of_training >= (max_used_level + 1. - share_predict_mean_distribution)
        ) and (max_used_level < hyperparameters['max_level'])
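        # For example, with epochs_per_level=100 and
        # share_predict_mean_distribution=0.1 (illustrative values), level-0
        # batches are flagged during epochs 90-99.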

    x, y, targets = [], [], []

    # Sample a sub-batch per level; for levels > 0 the targets are replaced by
    # fantasized one-step-lookahead outcomes computed with the current model.
    for considered_level in range(level + 1):
        num_elements = (styles == considered_level).sum()
        if not num_elements:
            continue
        returns: Batch = get_batch(batch_size=num_elements, seq_len=long_seq_len,
                                   num_features=num_features, device=device,
                                   hyperparameters=hyperparameters, model=model,
                                   single_eval_pos=single_eval_pos, epoch=epoch,
                                   **kwargs)
        levels_x, levels_y, levels_targets = returns.x, returns.y, returns.target_y
        assert not returns.other_filled_attributes(), f"Unexpected filled attributes: {returns.other_filled_attributes()}"

        assert levels_y is levels_targets
        levels_targets = levels_targets.clone()
        if len(levels_y.shape) == 2:
            levels_y = levels_y.unsqueeze(2)
            levels_targets = levels_targets.unsqueeze(2)
        if considered_level > 0:
            # Build one fantasized continuation per evaluation position: the
            # sub-batch is tiled `num_evals` times along the batch dimension and
            # position `single_eval_pos` of copy k is filled with candidate k.
            feed_x = levels_x[:single_eval_pos + 1 + add_seq_len].repeat(1, num_evals, 1)
            feed_x[single_eval_pos, :] = levels_x[single_eval_pos:seq_len].reshape(-1, *levels_x.shape[2:])
            if not use_mean_prediction:
                # Append the extra points on which each continuation is scored.
                feed_x[single_eval_pos + 1:] = levels_x[seq_len:].repeat(1, num_evals, 1)

            feed_y = levels_y[:single_eval_pos + 1 + add_seq_len].repeat(1, num_evals, 1)
            feed_y[single_eval_pos, :] = levels_y[single_eval_pos:seq_len].reshape(-1, *levels_y.shape[2:])
            if not use_mean_prediction:
                feed_y[single_eval_pos + 1:] = levels_y[seq_len:].repeat(1, num_evals, 1)

            # Score the fantasized continuations with the frozen model.
            model.eval()
            means = []
            for feed_x_b, feed_y_b in zip(torch.split(feed_x, hyperparameters['inf_batch_size'], dim=1),
                                          torch.split(feed_y, hyperparameters['inf_batch_size'], dim=1)):
                with torch.cuda.amp.autocast():
                    # Condition the model on the previous level's style.
                    style = torch.zeros(feed_x_b.shape[1], 1, dtype=torch.int64, device=device) + considered_level - 1
                    if level_0_model is not None and level_0_model.style_encoder is None:
                        style = None
                    out = model(
                        (style, feed_x_b, feed_y_b),
                        single_eval_pos=single_eval_pos + 1, only_return_standard_out=False
                    )
                    if isinstance(out, tuple):
                        output, once_output = out
                    else:
                        output = out
                        once_output = {}

                if once_output and 'mean_prediction' in once_output:
                    # With a mean-prediction head, take an optimistic quantile of
                    # the predicted mean distribution as the fantasized outcome.
                    mean_pred_logits = once_output['mean_prediction'].float()
                    assert tuple(mean_pred_logits.shape) == (feed_x_b.shape[1], model.criterion.num_bars), \
                        f"{tuple(mean_pred_logits.shape)} vs {(feed_x_b.shape[1], model.criterion.num_bars)}"
                    means.append(model.criterion.icdf(mean_pred_logits, 1. - 1. / eval_seq_len))
                else:
                    # Otherwise take the best predicted mean over the appended
                    # evaluation points.
                    logits = output['standard'].float()
                    means.append(model.criterion.mean(logits).max(0).values)
            means = torch.cat(means, 0)
            levels_targets_new = means.view(seq_len - single_eval_pos, *levels_y.shape[1:])
            levels_targets[single_eval_pos:seq_len] = levels_targets_new
            model.train()

        # Drop the appended evaluation points before assembling the batch.
        levels_x = levels_x[:seq_len]
        levels_y = levels_y[:seq_len]
        levels_targets = levels_targets[:seq_len]

        x.append(levels_x)
        y.append(levels_y)
        targets.append(levels_targets)

    x = torch.cat(x, 1)
    return Batch(x=x, y=torch.cat(y, 1), target_y=torch.cat(targets, 1), style=styles,
                 mean_prediction=predict_mean_distribution.item() if predict_mean_distribution is not None else None)
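
# A hedged end-to-end sketch of how this wrapper might be invoked; `my_prior`
# and all values below are illustrative assumptions, not part of the original
# source:
#
#   batch = get_batch(batch_size=64, seq_len=60, num_features=18,
#                     get_batch=my_prior.get_batch, model=model,
#                     single_eval_pos=50, epoch=0, device='cuda:0',
#                     hyperparameters=hyperparameters)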