|
import torch |
|
import torch.nn as nn |
|
import numpy as np |
|
import torch.nn.functional as F |
|
|
|
class LeCun(nn.Module):
    """LeCun-style scaled hyperbolic tangent activation.

    Computes ``1.7159 * tanh(0.666 * x)`` element-wise.
    """

    def __init__(self):
        super(LeCun, self).__init__()
        self.tanh = nn.Tanh()

    def forward(self, x):
        # Squash first, then rescale to the LeCun output range.
        squashed = self.tanh(0.666 * x)
        return 1.7159 * squashed
|
|
|
|
|
class Period_encoder_decoder_predication(nn.Module):
    """Forecast by folding the sequence into segments of one period length.

    The input is smoothed with a residual 1-D convolution, reshaped into
    ``seg_num_x`` segments of ``period_len`` steps, and two linear maps
    project segment count (input horizon -> output horizon) and the
    within-period pattern.
    """

    def __init__(self, seq_in_len, seq_out_len, enc_in, period_len):
        super(Period_encoder_decoder_predication, self).__init__()
        self.seq_len = seq_in_len
        self.pred_len = seq_out_len
        self.enc_in = enc_in
        self.period_len = period_len

        # Number of period-sized segments on the input / output horizons.
        # NOTE: assumes seq_in_len and seq_out_len are divisible by period_len.
        self.seg_num_x = seq_in_len // period_len
        self.seg_num_y = seq_out_len // period_len

        half = period_len // 2
        # Odd kernel (2*half + 1) with matching padding keeps the length fixed.
        self.conv1d = nn.Conv1d(in_channels=1, out_channels=1,
                                kernel_size=2 * half + 1, stride=1,
                                padding=half, padding_mode="zeros", bias=False)

        self.linear1 = nn.Linear(self.seg_num_x, self.seg_num_y)
        self.linear2 = nn.Linear(period_len, period_len)

    def forward(self, x):
        """x: (batch, seq_len, enc_in) -> (batch, enc_in, pred_len)."""
        bsz = x.shape[0]
        # (batch, seq_len, enc_in) -> (batch, enc_in, seq_len)
        x = x.permute(0, 2, 1)
        # Residual smoothing: each channel is convolved independently.
        smoothed = self.conv1d(x.reshape(-1, 1, self.seq_len))
        x = smoothed.reshape(-1, self.enc_in, self.seq_len) + x
        # Fold into periods: (batch*enc_in, period_len, seg_num_x)
        x = x.reshape(-1, self.seg_num_x, self.period_len).permute(0, 2, 1)
        # Map segment axis to the output horizon, then refine each period.
        y = self.linear1(x).permute(0, 2, 1)
        y = self.linear2(y).permute(0, 2, 1)
        y = y.permute(0, 2, 1).reshape(bsz, self.enc_in, self.pred_len)
        return y
|
|
|
|
|
class Multi_period_predication(nn.Module):
    """Weighted ensemble of per-period encoder/decoder forecasters.

    One :class:`Period_encoder_decoder_predication` branch is built per entry
    of ``period_channels``; their outputs are combined with a learnable
    per-branch weight.

    Args:
        seq_in_len: input horizon length.
        seq_out_len: output horizon length.
        enc_in: number of input channels/features.
        period_channels: iterable of period lengths (ints or int tensors).
        device: device the mixing weights are created on.
    """

    def __init__(self, seq_in_len, seq_out_len, enc_in, period_channels, device='cuda'):
        super(Multi_period_predication, self).__init__()
        self.seq_len = seq_in_len
        self.pred_len = seq_out_len
        self.enc_in = enc_in
        self.period_channels = period_channels
        self.device = device
        # One forecasting branch per candidate period length; tensor entries
        # are unwrapped to plain ints for the reshape arithmetic downstream.
        self.period_encoders = nn.ModuleList([
            Period_encoder_decoder_predication(
                seq_in_len, seq_out_len, enc_in,
                period_len.item() if isinstance(period_len, torch.Tensor) else period_len)
            for period_len in period_channels
        ])
        # Learnable mixing weight per branch, broadcast over (enc_in, pred_len).
        # NOTE(review): `.to(device)` runs at construction time, so the default
        # device='cuda' fails on CPU-only machines; pass device='cpu' there.
        self.period_weights = nn.Parameter(torch.ones(len(period_channels), 1, 1).to(device))

    def forward(self, x):
        """x: (batch, seq_len, enc_in) -> weighted sum, (batch, enc_in, pred_len)."""
        # (batch, n_branches, enc_in, pred_len)
        period_outputs = torch.stack(
            [encoder(x) for encoder in self.period_encoders], dim=1)
        # Collapse the branch axis with the learned weights.
        return (period_outputs * self.period_weights).sum(dim=1)
|
|
|
class Sequential_projection(nn.Module):
    """Bias-free linear projection of the trailing (time) dimension.

    Maps the last axis from ``seq_in_len`` to ``seq_out_len``.
    """

    def __init__(self, seq_in_len, seq_out_len):
        super(Sequential_projection, self).__init__()
        self.seq_len = seq_in_len
        self.pred_len = seq_out_len
        # nn.Linear acts on the final dimension, so no transpose is needed.
        self.linear = nn.Linear(self.seq_len, self.pred_len, bias=False)

    def forward(self, x):
        """x: (..., seq_len) -> (..., pred_len)."""
        projected = self.linear(x)
        return projected
|
|
|
|
|
class CfcCell(nn.Module):
    """Closed-form Continuous-time (CfC) recurrent cell.

    One step maps ``(input, hidden, elapsed_time)`` to the next hidden
    state.  A feed-forward "backbone" embeds ``[input, hidden]``; the heads
    on top of it realize either the gated CfC update or, when
    ``hparams["minimal"]`` is set, the minimal closed-form variant.

    hparams keys read here:
      - "backbone_activation": "silu" | "relu" | "tanh" | "gelu" | "lecun"
      - "backbone_units": width of the backbone layers
      - "backbone_layers": number of backbone layers
      - "backbone_dr" (optional): dropout prob. between backbone layers
      - "no_gate" (optional): additive instead of convex gating
      - "minimal" (optional): minimal closed-form solution
      - "init" (optional): gain for Xavier-uniform initialization
    """

    def __init__(self, input_size, hidden_size, hparams):
        super(CfcCell, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.hparams = hparams
        # Optional behavior flags, off unless present in hparams.
        self._no_gate = False
        if "no_gate" in self.hparams:
            self._no_gate = self.hparams["no_gate"]
        self._minimal = False
        if "minimal" in self.hparams:
            self._minimal = self.hparams["minimal"]

        # Resolve the activation CLASS (instantiated once per layer below).
        if self.hparams["backbone_activation"] == "silu":
            backbone_activation = nn.SiLU
        elif self.hparams["backbone_activation"] == "relu":
            backbone_activation = nn.ReLU
        elif self.hparams["backbone_activation"] == "tanh":
            backbone_activation = nn.Tanh
        elif self.hparams["backbone_activation"] == "gelu":
            backbone_activation = nn.GELU
        elif self.hparams["backbone_activation"] == "lecun":
            backbone_activation = LeCun
        else:
            raise ValueError("Unknown activation")
        # First backbone layer consumes the concatenated [input, hidden].
        layer_list = [
            nn.Linear(input_size + hidden_size, self.hparams["backbone_units"]),
            backbone_activation(),
        ]
        # Remaining hidden layers (Linear -> activation [-> dropout]).
        for i in range(1, self.hparams["backbone_layers"]):
            layer_list.append(
                nn.Linear(
                    self.hparams["backbone_units"], self.hparams["backbone_units"]
                )
            )
            layer_list.append(backbone_activation())
            if "backbone_dr" in self.hparams.keys():
                layer_list.append(torch.nn.Dropout(self.hparams["backbone_dr"]))
        self.backbone = nn.Sequential(*layer_list)
        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        # ff1 is shared by both variants; the remaining heads exist only
        # in the full (non-minimal) cell.
        self.ff1 = nn.Linear(self.hparams["backbone_units"], hidden_size)
        if self._minimal:
            # Minimal variant: learned time constant w_tau and amplitude A.
            self.w_tau = torch.nn.Parameter(
                data=torch.zeros(1, self.hidden_size), requires_grad=True
            )
            self.A = torch.nn.Parameter(
                data=torch.ones(1, self.hidden_size), requires_grad=True
            )
        else:
            # Full variant: two candidate states plus time-gating heads.
            self.ff2 = nn.Linear(self.hparams["backbone_units"], hidden_size)
            self.time_a = nn.Linear(self.hparams["backbone_units"], hidden_size)
            self.time_b = nn.Linear(self.hparams["backbone_units"], hidden_size)
        self.init_weights()

    def init_weights(self):
        """Xavier-uniform init for all 2-D weights when hparams["init"] is set."""
        init_gain = self.hparams.get("init")
        if init_gain is not None:
            for w in self.parameters():
                if w.dim() == 2:
                    torch.nn.init.xavier_uniform_(w, gain=init_gain)

    def forward(self, input, hx, ts):
        """One recurrent step.

        Args:
            input: assumed (batch, groups, input_size) — concatenated with hx
                on dim 2, matching how Cfc calls this cell; TODO confirm.
            hx: hidden state, (batch, groups, hidden_size).
            ts: per-sample elapsed time, viewable as (batch,).

        Returns:
            New hidden state, same shape as ``hx``.
        """
        batch_size = input.size(0)
        # Broadcast time over the group and feature axes: (batch, 1, 1).
        ts = ts.view(batch_size, 1).unsqueeze(1)
        # NOTE(review): nonstandard time transform exp(ts*(1 - 2*log(ts)));
        # the canonical CfC gates on ts directly — confirm this is intended.
        w_ts = torch.exp(ts * (1 - 2 * torch.log(ts)))
        x = torch.cat([input, hx], 2)
        x = self.backbone(x)
        if self._minimal:
            # Closed-form solution: h = -A * exp(-ts*(|w_tau|+|ff1|)) * ff1 + A
            ff1 = self.ff1(x)
            new_hidden = (
                -self.A
                * torch.exp(-ts * (torch.abs(self.w_tau) + torch.abs(ff1)))
                * ff1
                + self.A
            )
        else:
            # Two candidate states, interpolated by a time-dependent gate.
            ff1 = self.tanh(self.ff1(x))
            ff2 = self.tanh(self.ff2(x))
            t_a = self.time_a(x)
            t_b = self.time_b(x)
            t_interp = self.sigmoid(t_a * w_ts + t_b)
            if self._no_gate:
                # Additive variant: ff2 gated, ff1 passed through.
                new_hidden = ff1 + t_interp * ff2
            else:
                # Convex interpolation between the two candidates.
                new_hidden = ff1 * (1.0 - t_interp) + t_interp * ff2
        return new_hidden
|
|
|
|
|
|
|
|
|
class Cfc(nn.Module):
    """Full forecasting model: multi-period encoder + CfC recurrence + head.

    The input sequence is mean-centered, encoded by two
    ``Multi_period_predication`` branches, unrolled through a shared
    ``CfcCell`` in day-sized chunks (24 steps x 7 groups, i.e. a weekly
    168-step layout — see NOTE below), and projected to ``out_lens`` steps.
    """

    def __init__(
        self,
        in_features,
        hidden_size,
        out_feature,
        hparams,
        return_sequences=False,
        use_mixed=False,
        use_ltc=False,
    ):
        super(Cfc, self).__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.out_feature = out_feature
        self.return_sequences = return_sequences
        self.in_lens = hparams['in_lens']
        self.out_lens = hparams['out_lens']
        # Candidate period lengths for the multi-period encoders.
        self.period_lens = torch.tensor(hparams['period_len'], dtype=torch.int32)

        # NOTE(review): forward() only ever uses self.rnn_cell_forward, so
        # use_ltc=True would leave that attribute undefined — confirm.
        if use_ltc:
            self.rnn_cell = LTCCell(in_features, hidden_size)
        else:
            self.rnn_cell_forward = CfcCell(in_features, hidden_size, hparams)
        self.use_mixed = use_mixed
        self.fc = nn.Linear(self.hidden_size, self.out_feature)
        # Both encoders map in_lens -> in_lens (same-length reconstruction).
        # NOTE(review): these use Multi_period_predication's default
        # device='cuda', which fails on CPU-only machines.
        self.encoder = Multi_period_predication(self.in_lens, self.in_lens, self.in_features, self.period_lens)
        self.x_pro = Multi_period_predication(self.in_lens, self.in_lens, self.in_features, self.period_lens)
        # Projects the doubled (two-pass) unrolled sequence to out_lens steps.
        self.Sequential_projection = Sequential_projection(self.in_lens * 2, self.out_lens)

    def forward(self, x, timespans=None, mask=None):
        """Forecast. ``timespans`` and ``mask`` are accepted but unused.

        Assumes x is (batch, 1, n_features, seq_len) — TODO confirm against
        the caller; squeeze(1)+permute yields (batch, seq_len, n_features).
        """
        x = x.squeeze(1).permute(0, 2, 1)

        # Per-feature temporal mean, kept for re-adding at the output.
        seq_mean = torch.mean(x, dim=1).unsqueeze(2).permute(0, 2, 1)
        x = x - seq_mean
        # NOTE(review): hard-coded slice — assumes the first 3 features are
        # the forecast targets (must match out_feature); confirm.
        seq_mean_out = seq_mean[:, :, 0:3]

        # Multi-period reconstruction of the centered input.
        x_php = self.encoder(x).permute(0, 2, 1)
        x = x_php
        device = x.device
        batch_size = x.size(0)
        seq_len = x.size(1)
        true_in_features = x.size(2)
        # NOTE(review): hard-coded weekly layout — assumes in_lens == 24 * 7.
        module_len = 24
        module_num = 7
        # (batch, module_len, module_num, features): step t feeds all 7
        # day-groups through the cell in parallel.
        x = x.permute(0, 2, 1).reshape(batch_size, true_in_features, module_num, module_len).permute(0, 3, 2, 1)
        # First pass uses unit time steps.
        timespans_forward = torch.ones((batch_size, module_len)).to(device)
        h_state_forward = torch.zeros((batch_size, module_num, self.hidden_size), device=device)
        output_sequence_forward = []
        for t in range(module_len):
            inputs_forward = x[:, t]
            ts_forward = timespans_forward[:, t].squeeze()
            h_state_forward = self.rnn_cell_forward.forward(inputs_forward, h_state_forward, ts_forward)
            if self.return_sequences:
                output_sequence_forward.append(h_state_forward)
        if self.return_sequences:
            readout = torch.stack(output_sequence_forward, dim=1)
        else:
            readout = self.fc(h_state_forward)

        # Second pass with time step 1/168 (one step of the weekly horizon).
        timespans_forward_new = torch.zeros((batch_size, module_len)).to(device)
        timespans_forward_new += (1 / 168)
        h_state_forward_new = torch.zeros((batch_size, module_num, self.hidden_size), device=device)
        output_sequence_forward_new = []

        # Re-encode the first pass's reconstruction and unroll again.
        x_pro = self.x_pro(x_php).permute(0, 2, 1)
        x_pro = x_pro.permute(0, 2, 1).reshape(batch_size, true_in_features, 7, module_len).permute(0, 3, 2, 1)
        for t in range(module_len):
            inputs_forward_new = x_pro[:, t]
            ts_forward_new = timespans_forward_new[:, t].squeeze()
            h_state_forward_new = self.rnn_cell_forward.forward(inputs_forward_new, h_state_forward_new, ts_forward_new)
            if self.return_sequences:
                output_sequence_forward_new.append(h_state_forward_new)
        if self.return_sequences:
            readout_new = torch.stack(output_sequence_forward_new, dim=1)
        # NOTE(review): readout_new is only bound when return_sequences is
        # True; with the default return_sequences=False this line raises
        # NameError — the model appears to require return_sequences=True.
        readout = torch.cat((readout, readout_new), dim=2)
        # Flatten both passes to one time axis of length module_num*module_len*2.
        readout = readout.permute(0, 2, 1, 3).reshape(batch_size, module_num*module_len*2, self.hidden_size)
        # Project time axis to out_lens, then hidden to out_feature.
        readout = self.Sequential_projection(readout.permute(0, 2, 1)).permute(0, 2, 1)
        readout = self.fc(readout)

        # Restore the subtracted per-feature mean.
        readout = readout + seq_mean_out
        return readout
|
|
|
|
|
class LTCCell(nn.Module):
    """Liquid Time-Constant (LTC) recurrent cell.

    Advances the membrane potentials of ``units`` neurons by solving the
    LTC ODE with a fused semi-implicit scheme unrolled ``ode_unfolds``
    times per call.  Sensory (input) and recurrent synapses each carry
    weight ``w``, sigmoid parameters ``mu``/``sigma``, and reversal
    potentials ``erev``.
    """

    def __init__(
        self,
        in_features,
        units,
        ode_unfolds=2,
        epsilon=1e-8,
    ):
        super(LTCCell, self).__init__()
        self.in_features = in_features
        self.units = units
        # Uniform init ranges (min, max) for each parameter group.
        self._init_ranges = {
            "gleak": (0.001, 1.0),
            "vleak": (-0.2, 0.2),
            "cm": (0.4, 0.6),
            "w": (0.001, 1.0),
            "sigma": (3, 8),
            "mu": (0.3, 0.8),
            "sensory_w": (0.001, 1.0),
            "sensory_sigma": (3, 8),
            "sensory_mu": (0.3, 0.8),
        }
        self._ode_unfolds = ode_unfolds
        self._epsilon = epsilon

        # NOTE(review): named "softplus" but deliberately an identity — the
        # positivity that softplus would enforce is instead applied via
        # apply_weight_constraints()/_clip(); confirm this is intended.
        self.softplus = nn.Identity()
        self._allocate_parameters()

    @property
    def state_size(self):
        # Number of neurons (size of the hidden state).
        return self.units

    @property
    def sensory_size(self):
        # Number of input features (sensory synapses per neuron).
        return self.in_features

    def add_weight(self, name, init_value):
        """Create, register, and return a trainable parameter ``name``."""
        param = torch.nn.Parameter(init_value)
        self.register_parameter(name, param)
        return param

    def _get_init_value(self, shape, param_name):
        """Sample an init tensor uniformly from the range of ``param_name``."""
        minval, maxval = self._init_ranges[param_name]
        if minval == maxval:
            return torch.ones(shape) * minval
        else:
            return torch.rand(*shape) * (maxval - minval) + minval

    def _erev_initializer(self, shape=None):
        # Random excitatory (+1) / inhibitory (-1) reversal polarities.
        return np.random.default_rng().choice([-1, 1], size=shape)

    def _allocate_parameters(self):
        """Allocate all trainable parameters and index them in self._params."""
        self._params = {}
        # Per-neuron leak conductance, leak potential, membrane capacitance.
        self._params["gleak"] = self.add_weight(
            name="gleak", init_value=self._get_init_value((self.state_size,), "gleak")
        )
        self._params["vleak"] = self.add_weight(
            name="vleak", init_value=self._get_init_value((self.state_size,), "vleak")
        )
        self._params["cm"] = self.add_weight(
            name="cm", init_value=self._get_init_value((self.state_size,), "cm")
        )
        # Recurrent synapse parameters, (state_size, state_size).
        self._params["sigma"] = self.add_weight(
            name="sigma",
            init_value=self._get_init_value(
                (self.state_size, self.state_size), "sigma"
            ),
        )
        self._params["mu"] = self.add_weight(
            name="mu",
            init_value=self._get_init_value((self.state_size, self.state_size), "mu"),
        )
        self._params["w"] = self.add_weight(
            name="w",
            init_value=self._get_init_value((self.state_size, self.state_size), "w"),
        )
        self._params["erev"] = self.add_weight(
            name="erev",
            init_value=torch.Tensor(
                self._erev_initializer((self.state_size, self.state_size))
            ),
        )
        # Sensory (input) synapse parameters, (sensory_size, state_size).
        self._params["sensory_sigma"] = self.add_weight(
            name="sensory_sigma",
            init_value=self._get_init_value(
                (self.sensory_size, self.state_size), "sensory_sigma"
            ),
        )
        self._params["sensory_mu"] = self.add_weight(
            name="sensory_mu",
            init_value=self._get_init_value(
                (self.sensory_size, self.state_size), "sensory_mu"
            ),
        )
        self._params["sensory_w"] = self.add_weight(
            name="sensory_w",
            init_value=self._get_init_value(
                (self.sensory_size, self.state_size), "sensory_w"
            ),
        )
        self._params["sensory_erev"] = self.add_weight(
            name="sensory_erev",
            init_value=torch.Tensor(
                self._erev_initializer((self.sensory_size, self.state_size))
            ),
        )

        # Affine input mapping (identity-like at init).
        self._params["input_w"] = self.add_weight(
            name="input_w",
            init_value=torch.ones((self.sensory_size,)),
        )
        self._params["input_b"] = self.add_weight(
            name="input_b",
            init_value=torch.zeros((self.sensory_size,)),
        )

    def _sigmoid(self, v_pre, mu, sigma):
        """Synaptic activation: sigmoid(sigma * (v_pre - mu)), broadcast
        over a trailing post-synaptic axis."""
        v_pre = torch.unsqueeze(v_pre, -1)
        mues = v_pre - mu
        x = sigma * mues
        return torch.sigmoid(x)

    def _ode_solver(self, inputs, state, elapsed_time):
        """Advance ``state`` by ``elapsed_time`` with a fused semi-implicit
        Euler scheme unrolled self._ode_unfolds times."""
        v_pre = state

        # Sensory synapse currents depend only on the inputs, so they are
        # computed once outside the unfold loop.
        sensory_w_activation = self.softplus(self._params["sensory_w"]) * self._sigmoid(
            inputs, self._params["sensory_mu"], self._params["sensory_sigma"]
        )

        sensory_rev_activation = sensory_w_activation * self._params["sensory_erev"]

        # Reduce over pre-synaptic (input) axis.
        w_numerator_sensory = torch.sum(sensory_rev_activation, dim=1)
        w_denominator_sensory = torch.sum(sensory_w_activation, dim=1)

        # Effective capacitance per unfold step.
        # NOTE(review): uses (elapsed_time + 1) rather than elapsed_time —
        # presumably to avoid division by zero; confirm against the
        # reference LTC formulation.
        cm_t = self.softplus(self._params["cm"]).view(1, -1) / (
            (elapsed_time + 1) / self._ode_unfolds
        )

        for t in range(self._ode_unfolds):
            # Recurrent synapse currents from the current potentials.
            w_activation = self.softplus(self._params["w"]) * self._sigmoid(
                v_pre, self._params["mu"], self._params["sigma"]
            )

            rev_activation = w_activation * self._params["erev"]

            # Reduce over pre-synaptic (neuron) axis, add sensory terms.
            w_numerator = torch.sum(rev_activation, dim=1) + w_numerator_sensory
            w_denominator = torch.sum(w_activation, dim=1) + w_denominator_sensory

            # Semi-implicit update: v <- (cm*v + gleak*vleak + I) /
            #                            (cm + gleak + g_total)
            numerator = (
                cm_t * v_pre
                + self.softplus(self._params["gleak"]) * self._params["vleak"]
                + w_numerator
            )
            denominator = cm_t + self.softplus(self._params["gleak"]) + w_denominator

            # Epsilon guards against a vanishing denominator.
            v_pre = numerator / (denominator + self._epsilon)
            # NOTE(review): breakpoint() is a debugging leftover — it drops
            # into pdb (or hangs a non-interactive run) on NaN; replace with
            # a raise or logging before production use.
            if torch.any(torch.isnan(v_pre)):
                breakpoint()
        return v_pre

    def _map_inputs(self, inputs):
        """Element-wise affine map of the raw inputs (input_w * x + input_b)."""
        inputs = inputs * self._params["input_w"]
        inputs = inputs + self._params["input_b"]
        return inputs

    def _map_outputs(self, state):
        """Affine map of the state to outputs.

        NOTE(review): "output_w"/"output_b" are never allocated in
        _allocate_parameters, so calling this raises KeyError; it appears to
        be dead code (forward() never calls it) — confirm before use.
        """
        output = state
        output = output * self._params["output_w"]
        output = output + self._params["output_b"]
        return output

    def _clip(self, w):
        # Project onto the non-negative orthant (ReLU as elementwise max(0, .)).
        return torch.nn.ReLU()(w)

    def apply_weight_constraints(self):
        """Clamp conductances/capacitances to be non-negative (call after
        each optimizer step)."""
        self._params["w"].data = self._clip(self._params["w"].data)
        self._params["sensory_w"].data = self._clip(self._params["sensory_w"].data)
        self._params["cm"].data = self._clip(self._params["cm"].data)
        self._params["gleak"].data = self._clip(self._params["gleak"].data)

    def forward(self, input, hx, ts):
        """One recurrent step: map inputs, then integrate the ODE for ``ts``.

        Args:
            input: sensory input, (batch, sensory_size).
            hx: membrane potentials, (batch, state_size).
            ts: elapsed time, viewable as (batch, 1).
        """
        ts = ts.view((-1, 1))
        inputs = self._map_inputs(input)

        next_state = self._ode_solver(inputs, hx, ts)

        return next_state
|
|
|
|
|
|
|
|