Spaces:
Sleeping
Sleeping
File size: 7,305 Bytes
165ee00 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 |
import torch
from .prior import Batch
from ..utils import default_device
# Process-wide cache of already loaded models, keyed by their 'group:index' name.
loaded_models = {}


def get_model(model_name, device):
    """
    Return the model identified by `model_name`, loading it on first use and caching it.

    The name is split at ':' into a group and an integer index. On a cache miss the model is
    taken from `submitit.get_executor().get_group(group)[index].results()[0][2]`
    (NOTE(review): presumably the third element of the first result of a finished training job;
    confirm against the job's return value).

    The cache is keyed by name only, so the cached model is moved to `device` on every call.
    Otherwise a second request with a different device would silently get the model on the
    device of the first request.

    :param model_name: string of the form 'group:index'.
    :param device: device the returned model should live on.
    :return: the (cached) model, moved to `device`.
    """
    if model_name not in loaded_models:
        # imported lazily so that submitit is only required when a model is actually loaded
        import submitit
        group, index = model_name.split(':')
        ex = submitit.get_executor()
        loaded_models[model_name] = ex.get_group(group)[int(index)].results()[0][2]
    model = loaded_models[model_name]
    # `.to` is called for its in-place effect (as in the original first-load path)
    model.to(device)
    return model
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, get_batch, model, single_eval_pos, epoch, device=default_device, hyperparameters=None, **kwargs):
    """
    Sample a batch from the wrapped prior `get_batch` and, for levels above 0, overwrite the
    targets of the evaluation positions with values predicted by `model`.

    Every batch element is assigned a "style" (its level). Elements of level 0 keep the targets
    of the wrapped prior. For elements of level k > 0 each evaluation point is appended, one at a
    time, to the `single_eval_pos` context points, the model is run with style k - 1, and the
    resulting scalar becomes the new target of that evaluation point. The scalar is either
    the `icdf` of the model's 'mean_prediction' output at `1 - 1 / eval_seq_len`, or
    the maximum predicted mean over `eval_seq_len` extra query points.

    Tensors are indexed as (sequence position, batch element, ...) throughout.

    Important Assumptions:
    'max_level', 'sample_only_one_level', 'eval_seq_len' and 'epochs_per_level' must be in
    `hyperparameters`; 'inf_batch_size' is read as soon as a level > 0 is sampled.
    `hyperparameters` must therefore not be left at its default of None.

    You can train a new model, based on an old one to only sample from a single level.
    You specify `level_0_model` as a group:index string and the model will be loaded from the
    checkpoint (see `get_model`); it then replaces `model` for the predictions made here.

    :param batch_size: number of batch elements to sample.
    :param seq_len: length of the returned sequences.
    :param num_features: forwarded unchanged to the wrapped prior.
    :param get_batch: the wrapped prior; called with keyword arguments and expected to return a
        `Batch` whose `y` and `target_y` are the same tensor object.
    :param model: model used to predict the targets of levels > 0 (unless `level_0_model` is set).
    :param single_eval_pos: number of context points; positions from here on are evaluation points.
    :param epoch: current epoch, assumed to start at 0; determines the maximal level.
    :param device: device on which the styles are created; forwarded to the wrapped prior.
    :param hyperparameters: dict with the keys listed above; also forwarded to the wrapped prior.
    :param kwargs: forwarded unchanged to the wrapped prior.
    :return: `Batch` with `x`, `y`, `target_y`, `style` and `mean_prediction` (a bool, or None if
        'share_predict_mean_distribution' is not set).
    """
    if level_0_model := hyperparameters.get('level_0_model', None):
        assert hyperparameters['sample_only_one_level'], "level_0_model only makes sense if you sample only one level"
        assert hyperparameters['max_level'] == 1, "level_0_model only makes sense if you sample only one level"
        level_0_model = get_model(level_0_model, device)
        model = level_0_model

    # the level describes how many fantasized steps are possible. This starts at 0 for the first epochs.
    epochs_per_level = hyperparameters['epochs_per_level']
    share_predict_mean_distribution = hyperparameters.get('share_predict_mean_distribution', 0.)
    # only ever used for its truth value
    use_mean_prediction = share_predict_mean_distribution or (
        model.decoder_dict_once is not None and 'mean_prediction' in model.decoder_dict_once
    )
    num_evals = seq_len - single_eval_pos
    level = min(epoch // epochs_per_level, hyperparameters['max_level'], num_evals - 1)
    if level_0_model:
        level = 1
    eval_seq_len = hyperparameters['eval_seq_len']
    # without a mean prediction head, extra query points are sampled and the max over them is used
    add_seq_len = 0 if use_mean_prediction else eval_seq_len
    long_seq_len = seq_len + add_seq_len

    # Styles are constant or sorted in ascending order, so concatenating the per-level
    # sub-batches in ascending level order below lines up with `styles`.
    if level_0_model:
        styles = torch.ones(batch_size, 1, device=device, dtype=torch.long)
    elif hyperparameters['sample_only_one_level']:
        styles = torch.randint(level + 1, (1, 1), device=device).repeat(batch_size, 1)
    else:
        styles = torch.randint(level + 1, (batch_size, 1), device=device).sort(0).values

    predict_mean_distribution = None
    if share_predict_mean_distribution:
        # tensor reduction instead of the builtin `max`, which would iterate row by row in Python
        max_used_level = styles.max()
        # below code assumes epochs are base 0!
        share_of_training = epoch / epochs_per_level
        predict_mean_distribution = (share_of_training >= (max_used_level + 1. - share_predict_mean_distribution)) and (max_used_level < hyperparameters['max_level'])

    x, y, targets = [], [], []

    for considered_level in range(level + 1):
        # plain int, so the wrapped prior receives an ordinary batch size rather than a tensor
        num_elements = (styles == considered_level).sum().item()
        if not num_elements:
            continue
        returns: Batch = get_batch(batch_size=num_elements, seq_len=long_seq_len,
                                   num_features=num_features, device=device,
                                   hyperparameters=hyperparameters, model=model,
                                   single_eval_pos=single_eval_pos, epoch=epoch,
                                   **kwargs)
        levels_x, levels_y, levels_targets = returns.x, returns.y, returns.target_y
        assert not returns.other_filled_attributes(), f"Unexpected filled attributes: {returns.other_filled_attributes()}"

        assert levels_y is levels_targets
        # the targets are overwritten below, `y` must stay untouched
        levels_targets = levels_targets.clone()
        if len(levels_y.shape) == 2:
            levels_y = levels_y.unsqueeze(2)
            levels_targets = levels_targets.unsqueeze(2)

        if considered_level > 0:
            # Build one copy of the context per evaluation point: copy j (batch columns
            # j * num_elements ... (j + 1) * num_elements - 1) holds evaluation point j at position
            # `single_eval_pos`, followed by the shared extra query points (if any).
            feed_x = levels_x[:single_eval_pos + 1 + add_seq_len].repeat(1, num_evals, 1)
            feed_x[single_eval_pos, :] = levels_x[single_eval_pos:seq_len].reshape(-1, *levels_x.shape[2:])
            if not use_mean_prediction:
                feed_x[single_eval_pos + 1:] = levels_x[seq_len:].repeat(1, num_evals, 1)

            feed_y = levels_y[:single_eval_pos + 1 + add_seq_len].repeat(1, num_evals, 1)
            feed_y[single_eval_pos, :] = levels_y[single_eval_pos:seq_len].reshape(-1, *levels_y.shape[2:])
            if not use_mean_prediction:
                feed_y[single_eval_pos + 1:] = levels_y[seq_len:].repeat(1, num_evals, 1)

            # remember the mode so that the model is handed back the way it was passed in
            was_training = model.training
            model.eval()
            means = []
            # run the model in chunks of at most `inf_batch_size` batch columns
            for feed_x_b, feed_y_b in zip(torch.split(feed_x, hyperparameters['inf_batch_size'], dim=1),
                                          torch.split(feed_y, hyperparameters['inf_batch_size'], dim=1)):
                with torch.cuda.amp.autocast():
                    # the predictions come from the level below the one being generated
                    style = torch.zeros(feed_x_b.shape[1], 1, dtype=torch.int64, device=device) + considered_level - 1
                    if level_0_model is not None and level_0_model.style_encoder is None:
                        style = None
                    out = model(
                        (style, feed_x_b, feed_y_b),
                        single_eval_pos=single_eval_pos + 1, only_return_standard_out=False
                    )
                    if isinstance(out, tuple):
                        output, once_output = out
                    else:
                        output = out
                        once_output = {}

                if once_output and 'mean_prediction' in once_output:
                    mean_pred_logits = once_output['mean_prediction'].float()
                    assert tuple(mean_pred_logits.shape) == (feed_x_b.shape[1], model.criterion.num_bars),\
                        f"{tuple(mean_pred_logits.shape)} vs {(feed_x_b.shape[1], model.criterion.num_bars)}"
                    # NOTE(review): presumably the quantile that stands in for the max over
                    # `eval_seq_len` points - confirm against the criterion's `icdf`
                    means.append(model.criterion.icdf(mean_pred_logits, 1. - 1. / eval_seq_len))
                else:
                    logits = output['standard'].float()
                    # maximum predicted mean over the extra query points
                    means.append(model.criterion.mean(logits).max(0).values)
            means = torch.cat(means, 0)
            levels_targets_new = means.view(seq_len - single_eval_pos, *levels_y.shape[1:])

            levels_targets[single_eval_pos:seq_len] = levels_targets_new
            model.train(was_training)

        # drop the extra query points again
        levels_x = levels_x[:seq_len]
        levels_y = levels_y[:seq_len]
        levels_targets = levels_targets[:seq_len]

        x.append(levels_x)
        y.append(levels_y)
        targets.append(levels_targets)

    x = torch.cat(x, 1)

    return Batch(x=x, y=torch.cat(y, 1), target_y=torch.cat(targets, 1), style=styles,
                 mean_prediction=predict_mean_distribution.item() if predict_mean_distribution is not None else None)
|