import math

import torch
from torch import nn

from ..utils import default_device
from .prior import Batch
from .utils import get_batch_to_dataloader


class MLP(torch.nn.Module):
    """An MLP with randomly initialized weights, used to sample random functions for the prior."""

    def __init__(self, num_inputs, num_layers, num_hidden, num_outputs, init_std=None, sparseness=0.0,
                 preactivation_noise_std=0.0, activation='tanh'):
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(num_inputs, num_hidden)]
            + [nn.Linear(num_hidden, num_hidden) for _ in range(num_layers - 2)]
            + [nn.Linear(num_hidden, num_outputs)]
        )
        self.init_std = init_std
        self.sparseness = sparseness
        self.reset_parameters()
        self.preactivation_noise_std = preactivation_noise_std
        self.activation = {
            'tanh': torch.nn.Tanh(),
            'relu': torch.nn.ReLU(),
            'elu': torch.nn.ELU(),
            'identity': torch.nn.Identity(),
        }[activation]

    def reset_parameters(self, init_std=None, sparseness=None):
        init_std = init_std if init_std is not None else self.init_std
        sparseness = sparseness if sparseness is not None else self.sparseness
        for linear in self.linears:
            linear.reset_parameters()
        with torch.no_grad():
            if init_std is not None:
                for linear in self.linears:
                    linear.weight.normal_(0, init_std)
                    linear.bias.normal_(0, init_std)
            if sparseness > 0.0:
                for linear in self.linears[1:-1]:
                    # Zero out a fraction `sparseness` of the hidden weights, rescaling the
                    # survivors by 1/sqrt(1 - sparseness) so the preactivation variance is unchanged.
                    linear.weight /= (1. - sparseness) ** (1 / 2)
                    linear.weight *= torch.bernoulli(torch.ones_like(linear.weight) * (1. - sparseness))

    def forward(self, x):
        for linear in self.linears[:-1]:
            x = linear(x)
            x = x + torch.randn_like(x) * self.preactivation_noise_std
            x = self.activation(x)
        x = self.linears[-1](x)
        return x
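

# A minimal usage sketch (illustrative, not part of the original module): draw one
# random function from the MLP prior and evaluate it on a handful of inputs. The
# `_demo_` name and the hyperparameter values below are assumptions for the example.
def _demo_mlp_prior_sample():
    num_features, seq_len = 3, 16
    mlp = MLP(num_inputs=num_features, num_layers=3, num_hidden=32, num_outputs=1,
              init_std=0.1, sparseness=0.2, preactivation_noise_std=0.0, activation='tanh')
    x = torch.randn(seq_len, num_features)
    # Inputs are scaled by 1/sqrt(num_features) so the first layer's preactivation
    # variance does not grow with the feature count (mirroring get_batch below).
    return mlp(x / math.sqrt(num_features))  # shape: (seq_len, 1)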


def sample_input(input_sampling_setting, batch_size, seq_len, num_features, device=default_device):
    if input_sampling_setting == 'normal':
        x = torch.randn(batch_size, seq_len, num_features, device=device)
        x_for_mlp = x
    elif input_sampling_setting == 'uniform':
        x = torch.rand(batch_size, seq_len, num_features, device=device)
        # Standardize Uniform(0, 1) inputs to zero mean and unit variance (Var(U(0, 1)) = 1/12).
        x_for_mlp = (x - .5) / math.sqrt(1 / 12)
    else:
        raise ValueError(f"Unknown input_sampling: {input_sampling_setting}")
    return x, x_for_mlp
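

# Quick sanity check (a sketch, not part of the original module): the 'uniform'
# branch should hand the MLP inputs with mean ~0 and std ~1 after standardization.
def _check_input_standardization():
    _, x_for_mlp = sample_input('uniform', batch_size=8, seq_len=256, num_features=4, device='cpu')
    print(f"mean={x_for_mlp.mean().item():.3f} std={x_for_mlp.std().item():.3f}")  # approx 0 and 1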


def get_batch(batch_size, seq_len, num_features, hyperparameters, device=default_device, num_outputs=1, **kwargs):
    if hyperparameters is None:
        hyperparameters = {
            'mlp_num_layers': 2,
            'mlp_num_hidden': 64,
            'mlp_init_std': 0.1,
            'mlp_sparseness': 0.2,
            'mlp_input_sampling': 'normal',
            'mlp_output_noise': 0.0,
            'mlp_noisy_targets': False,
            'mlp_preactivation_noise_std': 0.0,
        }
    x, x_for_mlp = sample_input(hyperparameters.get('mlp_input_sampling', 'normal'), batch_size, seq_len,
                                num_features, device=device)
    model = MLP(num_features, hyperparameters['mlp_num_layers'], hyperparameters['mlp_num_hidden'],
                num_outputs, hyperparameters['mlp_init_std'], hyperparameters['mlp_sparseness'],
                hyperparameters['mlp_preactivation_noise_std'], hyperparameters.get('activation', 'tanh')).to(device)
    # A copy of the model without preactivation noise, used to compute clean targets when requested.
    no_noise_model = MLP(num_features, hyperparameters['mlp_num_layers'], hyperparameters['mlp_num_hidden'],
                         num_outputs, hyperparameters['mlp_init_std'], hyperparameters['mlp_sparseness'],
                         0., hyperparameters.get('activation', 'tanh')).to(device)
    ys = []
    targets = []
    for x_ in x_for_mlp:  # iterate over the batch dimension; x_ is (seq_len, num_features)
        model.reset_parameters()  # sample a fresh random function for each dataset in the batch
        y = model(x_ / math.sqrt(num_features))
        ys.append(y.unsqueeze(1))
        if not hyperparameters.get('mlp_preactivation_noise_in_targets', True):
            assert not hyperparameters['mlp_noisy_targets']
            no_noise_model.load_state_dict(model.state_dict())
            target = no_noise_model(x_ / math.sqrt(num_features))
            targets.append(target.unsqueeze(1))
    y = torch.cat(ys, dim=1)  # (seq_len, batch_size, num_outputs)
    targets = torch.cat(targets, dim=1) if targets else y
    noisy_y = y + torch.randn_like(y) * hyperparameters['mlp_output_noise']
    return Batch(x.transpose(0, 1), noisy_y, noisy_y if hyperparameters['mlp_noisy_targets'] else targets)


DataLoader = get_batch_to_dataloader(get_batch)
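

# Usage sketch (assumptions for the example: this module lives inside the package so
# the relative imports resolve, and Batch stores its first three constructor arguments
# as `x`, `y`, and `target_y`, matching how it is constructed in get_batch above).
def _demo_get_batch():
    batch = get_batch(batch_size=4, seq_len=50, num_features=3, hyperparameters=None, device='cpu')
    # Tensors come back sequence-first:
    #   batch.x: (seq_len, batch_size, num_features)
    #   batch.y: (seq_len, batch_size, num_outputs)
    return batch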