import math

import torch
from torch import nn
from ..utils import default_device
from .prior import Batch
from .utils import get_batch_to_dataloader
class MLP(torch.nn.Module):
def __init__(self, num_inputs, num_layers, num_hidden, num_outputs, init_std=None, sparseness=0.0,
preactivation_noise_std=0.0, activation='tanh'):
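        """A fully connected network with randomly initialized (and optionally
        sparsified) weights, used here to sample functions for prior-data
        generation rather than as a trainable model."""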
        super().__init__()
        self.linears = nn.ModuleList(
            [nn.Linear(num_inputs, num_hidden)]
            + [nn.Linear(num_hidden, num_hidden) for _ in range(num_layers - 2)]
            + [nn.Linear(num_hidden, num_outputs)]
        )
self.init_std = init_std
self.sparseness = sparseness
self.reset_parameters()
self.preactivation_noise_std = preactivation_noise_std
self.activation = {
'tanh': torch.nn.Tanh(),
'relu': torch.nn.ReLU(),
'elu': torch.nn.ELU(),
'identity': torch.nn.Identity(),
}[activation]
def reset_parameters(self, init_std=None, sparseness=None):
init_std = init_std if init_std is not None else self.init_std
sparseness = sparseness if sparseness is not None else self.sparseness
for linear in self.linears:
linear.reset_parameters()
with torch.no_grad():
if init_std is not None:
for linear in self.linears:
linear.weight.normal_(0, init_std)
linear.bias.normal_(0, init_std)
if sparseness > 0.0:
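                # Sparsify only the hidden-to-hidden layers: mask out a random
                # fraction `sparseness` of the weights and rescale the rest by
                # 1/sqrt(1 - sparseness), which keeps the expected
                # pre-activation variance unchanged.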
for linear in self.linears[1:-1]:
linear.weight /= (1. - sparseness) ** (1 / 2)
linear.weight *= torch.bernoulli(torch.ones_like(linear.weight) * (1. - sparseness))
    def forward(self, x):
        for linear in self.linears[:-1]:
            x = linear(x)
            # Optional Gaussian noise on the pre-activations.
            x = x + torch.randn_like(x) * self.preactivation_noise_std
            # Apply the activation configured in __init__ (tanh by default).
            x = self.activation(x)
        x = self.linears[-1](x)
        return x
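

# A minimal usage sketch (not part of the original module): draw one random
# function from the MLP prior defined above. The hyperparameter values are
# illustrative assumptions, not defaults used elsewhere in this codebase.
def _example_draw_from_mlp_prior(seq_len=100, num_features=5):
    mlp = MLP(num_inputs=num_features, num_layers=3, num_hidden=64,
              num_outputs=1, init_std=0.1, sparseness=0.2)
    with torch.no_grad():
        xs = torch.randn(seq_len, num_features)
        # Scale by 1/sqrt(num_features), as get_batch below does, so the size
        # of the first layer's pre-activations is independent of num_features.
        return mlp(xs / math.sqrt(num_features))  # shape: (seq_len, 1)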
def sample_input(input_sampling_setting, batch_size, seq_len, num_features, device=default_device):
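    """Draws raw inputs x and a standardized copy x_for_mlp to feed the MLP."""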
if input_sampling_setting == 'normal':
x = torch.randn(batch_size, seq_len, num_features, device=device)
x_for_mlp = x
elif input_sampling_setting == 'uniform':
x = torch.rand(batch_size, seq_len, num_features, device=device)
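        # Standardize Uniform(0, 1) samples to zero mean and unit variance
        # (Var[U(0, 1)] = 1/12) before they enter the MLP.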
x_for_mlp = (x - .5)/math.sqrt(1/12)
else:
raise ValueError(f"Unknown input_sampling: {input_sampling_setting}")
return x, x_for_mlp
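

# A quick sanity-check sketch (illustrative, not part of the module): under
# both settings, x_for_mlp should be roughly zero-mean with unit variance.
def _example_check_input_standardization():
    for setting in ('normal', 'uniform'):
        _, x_for_mlp = sample_input(setting, batch_size=8, seq_len=256,
                                    num_features=4, device='cpu')
        print(setting, float(x_for_mlp.mean()), float(x_for_mlp.std()))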
@torch.no_grad()
def get_batch(batch_size, seq_len, num_features, hyperparameters, device=default_device, num_outputs=1, **kwargs):
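    """Samples a batch from the MLP prior: one freshly initialized MLP per
    batch element, with optional output noise and noise-free targets."""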
if hyperparameters is None:
hyperparameters = {
'mlp_num_layers': 2,
'mlp_num_hidden': 64,
'mlp_init_std': 0.1,
'mlp_sparseness': 0.2,
'mlp_input_sampling': 'normal',
'mlp_output_noise': 0.0,
'mlp_noisy_targets': False,
'mlp_preactivation_noise_std': 0.0,
}
x, x_for_mlp = sample_input(hyperparameters.get('mlp_input_sampling', 'normal'), batch_size, seq_len, num_features,
device=device)
model = MLP(num_features, hyperparameters['mlp_num_layers'], hyperparameters['mlp_num_hidden'],
num_outputs, hyperparameters['mlp_init_std'], hyperparameters['mlp_sparseness'],
hyperparameters['mlp_preactivation_noise_std'], hyperparameters.get('activation', 'tanh')).to(device)
no_noise_model = MLP(num_features, hyperparameters['mlp_num_layers'], hyperparameters['mlp_num_hidden'],
num_outputs, hyperparameters['mlp_init_std'], hyperparameters['mlp_sparseness'],
0., hyperparameters.get('activation', 'tanh')).to(device)
ys = []
targets = []
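    # One independent draw from the function prior per batch element:
    # re-initialize the MLP's weights before evaluating each sequence.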
for x_ in x_for_mlp:
model.reset_parameters()
y = model(x_ / math.sqrt(num_features))
ys.append(y.unsqueeze(1))
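        # Optionally recompute the targets with a noise-free copy of the same
        # weights, so preactivation noise appears only in the observed ys.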
if not hyperparameters.get('mlp_preactivation_noise_in_targets', True):
assert not hyperparameters['mlp_noisy_targets']
no_noise_model.load_state_dict(model.state_dict())
target = no_noise_model(x_ / math.sqrt(num_features))
targets.append(target.unsqueeze(1))
y = torch.cat(ys, dim=1)
targets = torch.cat(targets, dim=1) if targets else y
noisy_y = y + torch.randn_like(y) * hyperparameters['mlp_output_noise']
return Batch(x.transpose(0, 1), noisy_y, (noisy_y if hyperparameters['mlp_noisy_targets'] else targets))
DataLoader = get_batch_to_dataloader(get_batch)
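

# A minimal sketch of drawing one batch from this prior; the Batch field names
# mentioned in the comment below (x, y, target_y) are assumed for illustration
# and should be checked against the Batch definition in .prior.
def _example_get_batch():
    batch = get_batch(batch_size=4, seq_len=50, num_features=3,
                      hyperparameters=None, device='cpu')
    # With the defaults above, batch.x should be (50, 4, 3) and batch.y /
    # batch.target_y should be (50, 4, 1) (seq-first layout).
    return batch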