import math

import torch
from torch import nn

from ..utils import default_device
from .prior import Batch
from .utils import get_batch_to_dataloader


class MLP(torch.nn.Module):
    """A randomly initialized MLP used as a prior over functions, with optional
    sparse hidden-layer weights and Gaussian pre-activation noise."""

    def __init__(self, num_inputs, num_layers, num_hidden, num_outputs, init_std=None, sparseness=0.0,
                 preactivation_noise_std=0.0, activation='tanh'):
        super().__init__()
        # num_layers counts all linear layers: input -> hidden,
        # (num_layers - 2) hidden -> hidden, and hidden -> output.
        self.linears = nn.ModuleList(
            [nn.Linear(num_inputs, num_hidden)]
            + [nn.Linear(num_hidden, num_hidden) for _ in range(num_layers - 2)]
            + [nn.Linear(num_hidden, num_outputs)]
        )

        self.init_std = init_std
        self.sparseness = sparseness
        self.reset_parameters()

        self.preactivation_noise_std = preactivation_noise_std
        self.activation = {
            'tanh': torch.nn.Tanh(),
            'relu': torch.nn.ReLU(),
            'elu': torch.nn.ELU(),
            'identity': torch.nn.Identity(),
        }[activation]

    def reset_parameters(self, init_std=None, sparseness=None):
        """Re-draw all weights; arguments default to the values given at construction."""
        init_std = init_std if init_std is not None else self.init_std
        sparseness = sparseness if sparseness is not None else self.sparseness
        for linear in self.linears:
            linear.reset_parameters()

        with torch.no_grad():
            if init_std is not None:
                for linear in self.linears:
                    linear.weight.normal_(0, init_std)
                    linear.bias.normal_(0, init_std)

            if sparseness > 0.0:
                # Zero out a fraction `sparseness` of the hidden-layer weights and
                # rescale the survivors by 1/sqrt(1 - sparseness), so the expected
                # squared weight magnitude is preserved.
                for linear in self.linears[1:-1]:
                    linear.weight /= (1. - sparseness) ** (1 / 2)
                    linear.weight *= torch.bernoulli(torch.ones_like(linear.weight) * (1. - sparseness))

    def forward(self, x):
        for linear in self.linears[:-1]:
            x = linear(x)
            x = x + torch.randn_like(x) * self.preactivation_noise_std
            x = self.activation(x)  # was hard-coded to tanh, silently ignoring the `activation` argument
        x = self.linears[-1](x)
        return x
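
# A minimal usage sketch of MLP on its own (shapes and hyperparameter values
# here are illustrative assumptions, not fixed by this module):
#
#     mlp = MLP(num_inputs=5, num_layers=3, num_hidden=64, num_outputs=1,
#               init_std=0.1, sparseness=0.2, activation='relu')
#     y = mlp(torch.randn(128, 5))  # -> shape (128, 1)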


def sample_input(input_sampling_setting, batch_size, seq_len, num_features, device=default_device):
    """Sample raw inputs `x` plus a standardized copy `x_for_mlp` to feed the MLP."""
    if input_sampling_setting == 'normal':
        x = torch.randn(batch_size, seq_len, num_features, device=device)
        x_for_mlp = x
    elif input_sampling_setting == 'uniform':
        x = torch.rand(batch_size, seq_len, num_features, device=device)
        # Standardize: Uniform(0, 1) has mean 1/2 and variance 1/12.
        x_for_mlp = (x - .5) / math.sqrt(1 / 12)
    else:
        raise ValueError(f"Unknown input_sampling: {input_sampling_setting}")
    return x, x_for_mlp
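
# Example: with 'uniform' sampling, x keeps its raw Uniform(0, 1) values while
# x_for_mlp is standardized to zero mean and unit variance (argument values
# below are illustrative):
#
#     x, x_for_mlp = sample_input('uniform', 4, 100, 5, device='cpu')
#     x_for_mlp.mean(), x_for_mlp.std()  # ~0. and ~1. respectively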


@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, hyperparameters, device=default_device, num_outputs=1, **kwargs):
    """Sample one batch from the MLP prior: a fresh randomly initialized MLP is
    drawn for every dataset in the batch and evaluated on sampled inputs."""
    if hyperparameters is None:
        hyperparameters = {
            'mlp_num_layers': 2,
            'mlp_num_hidden': 64,
            'mlp_init_std': 0.1,
            'mlp_sparseness': 0.2,
            'mlp_input_sampling': 'normal',
            'mlp_output_noise': 0.0,
            'mlp_noisy_targets': False,
            'mlp_preactivation_noise_std': 0.0,
        }

    x, x_for_mlp = sample_input(hyperparameters.get('mlp_input_sampling', 'normal'), batch_size, seq_len, num_features,
                     device=device)

    model = MLP(num_features, hyperparameters['mlp_num_layers'], hyperparameters['mlp_num_hidden'],
                num_outputs, hyperparameters['mlp_init_std'], hyperparameters['mlp_sparseness'],
                hyperparameters['mlp_preactivation_noise_std'], hyperparameters.get('activation', 'tanh')).to(device)

    # A twin of `model` with preactivation_noise_std=0., used below to recompute
    # noise-free targets from the same weights.
    no_noise_model = MLP(num_features, hyperparameters['mlp_num_layers'], hyperparameters['mlp_num_hidden'],
                         num_outputs, hyperparameters['mlp_init_std'], hyperparameters['mlp_sparseness'],
                         0., hyperparameters.get('activation', 'tanh')).to(device)

    ys = []
    targets = []
    for x_ in x_for_mlp:  # one independently sampled MLP per dataset in the batch
        model.reset_parameters()
        # Divide by sqrt(num_features) so the pre-activation scale does not grow
        # with the input dimension.
        y = model(x_ / math.sqrt(num_features))
        ys.append(y.unsqueeze(1))
        if not hyperparameters.get('mlp_preactivation_noise_in_targets', True):
            assert not hyperparameters['mlp_noisy_targets']
            # Recompute targets with the same weights but without pre-activation noise.
            no_noise_model.load_state_dict(model.state_dict())
            target = no_noise_model(x_ / math.sqrt(num_features))
            targets.append(target.unsqueeze(1))

    y = torch.cat(ys, dim=1)  # (seq_len, batch_size, num_outputs)
    targets = torch.cat(targets, dim=1) if targets else y

    noisy_y = y + torch.randn_like(y) * hyperparameters['mlp_output_noise']

    # x is (batch, seq, features); transpose to seq-first to match y's layout.
    return Batch(x.transpose(0, 1), noisy_y, noisy_y if hyperparameters['mlp_noisy_targets'] else targets)

DataLoader = get_batch_to_dataloader(get_batch)
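
if __name__ == '__main__':
    # Smoke-test sketch; argument values are illustrative. hyperparameters=None
    # falls back to the defaults defined in get_batch. Because of the relative
    # imports above, run this as a module (python -m <package>.mlp), not as a script.
    batch = get_batch(batch_size=4, seq_len=100, num_features=5,
                      hyperparameters=None, device='cpu')
    print(batch)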