import torch

from ..utils import default_device


@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, get_batch, model, single_eval_pos, epoch, device=default_device, hyperparameters={}, **kwargs):
    """Wrap another prior's `get_batch` and optionally normalize its outputs.

    If `normalize_x` is set in `hyperparameters`, the inputs are rescaled by the mean and
    std of the first positions (the whole sequence, or only the positions before
    `single_eval_pos` when `normalize_x == 'train'`); the targets are normalized analogously
    when `normalize_y` is set. The normalization statistics are concatenated into
    `returns.style`.
    """
    if hyperparameters.get('normalize_x', False):
        # Randomly rescale the inner prior's 'sampling' hyperparameter so the generated
        # inputs vary in scale before being normalized again below.
        uniform_float = torch.rand(tuple()).clamp(.1, 1.).item()
        new_hyperparameters = {**hyperparameters, 'sampling': uniform_float * hyperparameters['sampling']}
    else:
        new_hyperparameters = hyperparameters

    # Delegate the actual batch generation to the wrapped prior's `get_batch`.
    returns = get_batch(batch_size=batch_size, seq_len=seq_len,
                        num_features=num_features, device=device,
                        hyperparameters=new_hyperparameters, model=model,
                        single_eval_pos=single_eval_pos, epoch=epoch, **kwargs)

    style = []

    if normalize_x_mode := hyperparameters.get('normalize_x', False):
        returns.x, mean_style, std_style = normalize_data_by_first_k(
            returns.x, single_eval_pos if normalize_x_mode == 'train' else len(returns.x))
        if hyperparameters.get('style_includes_mean_from_normalization', True) or normalize_x_mode == 'train':
            style.append(mean_style)
        style.append(std_style)

    if hyperparameters.get('normalize_y', False):
        # Targets are always normalized by the training positions only.
        returns.y, mean_style, std_style = normalize_data_by_first_k(returns.y, single_eval_pos)
        style += [mean_style, std_style]

    returns.style = torch.cat(style, 1) if style else None
    return returns
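
# Illustrative usage sketch (hypothetical names, not part of this code base): the wrapper
# above expects the inner prior's `get_batch` to return a batch object with `.x` and `.y`
# tensors of shape (seq_len, batch_size, ...) and a writable `.style` attribute, e.g.:
#
#   returns = get_batch(batch_size=8, seq_len=100, num_features=5,
#                       get_batch=inner_prior_get_batch, model=None,
#                       single_eval_pos=50, epoch=0,
#                       hyperparameters={'normalize_x': True, 'normalize_y': True,
#                                        'sampling': 1.0})
#   # returns.x / returns.y come back normalized and returns.style holds the statistics.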


def normalize_data_by_first_k(x, k):
    # x has shape (seq_len, batch_size, num_features) or (seq_len, num_features)
    # k is the number of leading elements whose statistics are used for normalization
    unsqueezed_x = False
    if len(x.shape) == 2:
        x.unsqueeze_(2)
        unsqueezed_x = True

    if k > 1:
        relevant_x = x[:k]
        mean_style = relevant_x.mean(0)
        std_style = relevant_x.std(0)
        x = (x - relevant_x.mean(0, keepdim=True)) / relevant_x.std(0, keepdim=True)
    elif k == 1:
        # The std of a single element is undefined, so only center the data.
        mean_style = x[0]
        std_style = torch.ones_like(x[0])
        x = x - x[0]
    else:  # k == 0: no elements to normalize by, so return identity statistics.
        mean_style = torch.zeros_like(x[0])
        std_style = torch.ones_like(x[0])

    if unsqueezed_x:
        x.squeeze_(2)

    return x, mean_style, std_style
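

# Minimal sketch (assumed toy shapes, not values used anywhere in this code base):
#
#   x = torch.randn(10, 4, 2)                        # (seq_len, batch_size, num_features)
#   x_norm, mean_style, std_style = normalize_data_by_first_k(x, k=3)
#   # mean_style and std_style have shape (4, 2); the first 3 positions of x_norm
#   # now have (approximately) zero mean and unit std per batch element and feature.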