# DMP-PCFC / torch_cfc.py
# Uploaded by XingyuLiang ("Upload 75 files"), commit cd1df48 (verified).
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
class LeCun(nn.Module):
    """LeCun-style scaled hyperbolic tangent: ``1.7159 * tanh(0.666 * x)``.

    The slope constant is written as 0.666 (not exactly 2/3).
    """

    def __init__(self):
        super().__init__()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Apply the scaled tanh element-wise."""
        squashed = self.tanh(x * 0.666)
        return squashed * 1.7159
class Period_encoder_decoder_predication(nn.Module):
    """Periodic-segment predictor mapping ``seq_in_len`` steps to ``seq_out_len`` steps.

    The input is smoothed by a single 1-D convolution shared across channels (added back
    as a residual) and cut into ``seq_in_len // period_len`` consecutive segments of
    ``period_len`` steps. ``linear1`` maps the segment count to
    ``seq_out_len // period_len`` (independently for each phase inside the period) and
    ``linear2`` mixes the steps inside each output segment.

    Args:
        seq_in_len: Input length; must be a multiple of ``period_len``.
        seq_out_len: Output length; must be a multiple of ``period_len``.
        enc_in: Number of channels.
        period_len: Segment length in steps.

    Raises:
        ValueError: If ``period_len`` is not positive or does not divide both lengths.
            Without this check the reshapes in ``forward`` can silently regroup values
            across channels when the lengths are not multiples of ``period_len``.
    """

    def __init__(self, seq_in_len, seq_out_len, enc_in, period_len):
        super().__init__()
        if period_len <= 0:
            raise ValueError(f"period_len must be positive, got {period_len}")
        if seq_in_len % period_len or seq_out_len % period_len:
            raise ValueError(
                f"seq_in_len ({seq_in_len}) and seq_out_len ({seq_out_len}) must both be "
                f"multiples of period_len ({period_len})"
            )
        self.seq_len = seq_in_len
        self.pred_len = seq_out_len
        self.enc_in = enc_in
        self.period_len = period_len
        self.seg_num_x = self.seq_len // self.period_len
        self.seg_num_y = self.pred_len // self.period_len
        # Odd kernel with "same" padding, so the smoothed sequence keeps seq_len steps.
        self.conv1d = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=1 + 2 * (self.period_len // 2),
                                stride=1, padding=self.period_len // 2, padding_mode="zeros", bias=False)
        self.linear1 = nn.Linear(self.seg_num_x, self.seg_num_y)
        self.linear2 = nn.Linear(self.period_len, self.period_len)

    def forward(self, x):
        """Predict ``pred_len`` steps for every channel.

        Args:
            x: Tensor of shape ``(batch, seq_len, enc_in)``.

        Returns:
            Tensor of shape ``(batch, enc_in, pred_len)``.
        """
        batch_size = x.shape[0]
        x = x.permute(0, 2, 1)  # (b, c, s)
        # Shared smoothing convolution applied to each channel, kept as a residual.
        smoothed = self.conv1d(x.reshape(-1, 1, self.seq_len))
        x = smoothed.reshape(-1, self.enc_in, self.seq_len) + x
        # (b*c, period_len, seg_num_x): column j holds the j-th period of a channel.
        x = x.reshape(-1, self.seg_num_x, self.period_len).permute(0, 2, 1)
        y = self.linear1(x).permute(0, 2, 1)  # (b*c, seg_num_y, period_len)
        y = self.linear2(y)
        return y.reshape(batch_size, self.enc_in, self.pred_len)
class Multi_period_predication(nn.Module):
    """Learned weighted sum of several ``Period_encoder_decoder_predication`` branches.

    One branch is built per entry of ``period_channels``. Their outputs are stacked and
    combined with one learnable scalar weight per branch (initialised to 1).

    Args:
        seq_in_len: Input sequence length.
        seq_out_len: Output sequence length.
        enc_in: Number of channels.
        period_channels: Iterable of period lengths (Python ints or scalar tensors).
        device: Device on which the branch weights are created. A CUDA device falls
            back to CPU when CUDA is unavailable (the old hard-coded ``'cuda'`` default
            crashed on CPU-only machines). Move the whole module with ``.to(...)`` as
            usual; the branch sub-modules are always created on the default device.
    """

    def __init__(self, seq_in_len, seq_out_len, enc_in, period_channels, device='cuda'):
        super().__init__()
        self.seq_len = seq_in_len
        self.pred_len = seq_out_len
        self.enc_in = enc_in
        self.period_channels = period_channels
        if device is not None:
            device = torch.device(device)
            if device.type == "cuda" and not torch.cuda.is_available():
                device = torch.device("cpu")
        self.device = device
        period_lens = [p.item() if isinstance(p, torch.Tensor) else p for p in period_channels]
        self.period_encoders = nn.ModuleList(
            [
                Period_encoder_decoder_predication(seq_in_len, seq_out_len, enc_in, period_len)
                for period_len in period_lens
            ]
        )
        # (num_periods, 1, 1): broadcasts over enc_in and pred_len in forward().
        self.period_weights = nn.Parameter(torch.ones(len(period_lens), 1, 1, device=device))

    def forward(self, x):
        """Combine the predictions of all period branches.

        Args:
            x: Tensor of shape ``(batch, seq_len, enc_in)``; each branch permutes it to
                ``(batch, enc_in, seq_len)`` internally.

        Returns:
            Tensor of shape ``(batch, enc_in, pred_len)``.
        """
        # (batch, num_periods, enc_in, pred_len)
        branch_outputs = torch.stack([encoder(x) for encoder in self.period_encoders], dim=1)
        return (branch_outputs * self.period_weights).sum(dim=1)
class Sequential_projection(nn.Module):
    """Bias-free linear map along the last axis, from ``seq_in_len`` to ``seq_out_len``."""

    def __init__(self, seq_in_len, seq_out_len):
        super().__init__()
        self.seq_len, self.pred_len = seq_in_len, seq_out_len
        self.linear = nn.Linear(in_features=seq_in_len, out_features=seq_out_len, bias=False)

    def forward(self, x):
        """Project ``x`` of shape ``(..., seq_len)`` to ``(..., pred_len)``."""
        return self.linear(x)
class CfcCell(nn.Module):
    """Closed-form continuous-time (CfC) recurrent cell with a warped elapsed-time input.

    A backbone MLP over ``cat([input, hx], dim=2)`` feeds either:

    * ``minimal`` mode: ``-A * exp(-ts * (|w_tau| + |f|)) * f + A`` with ``f = ff1(x)``;
    * gated mode: ``g * (1 - s) + s * h`` (or ``g + s * h`` when ``no_gate``), where
      ``g = tanh(ff1(x))``, ``h = tanh(ff2(x))`` and
      ``s = sigmoid(time_a(x) * w(ts) + time_b(x))``.

    The elapsed time is warped as ``w(ts) = exp(ts * (1 - 2 * ln(ts)))``. At ``ts == 0``
    its limit ``w = 1`` is used instead of the NaN produced by ``0 * ln(0)``.

    Args:
        input_size: Size of the last dim of ``input``.
        hidden_size: Size of the hidden state.
        hparams: Dict with required keys ``backbone_activation`` (``silu``, ``relu``,
            ``tanh``, ``gelu`` or ``lecun``), ``backbone_units`` and ``backbone_layers``.
            Optional keys: ``no_gate``, ``minimal``, ``backbone_dr`` (dropout rate) and
            ``init`` (Xavier gain applied to every 2-D parameter).

    Raises:
        ValueError: If ``backbone_activation`` is not recognised.
    """

    def __init__(self, input_size, hidden_size, hparams):
        super(CfcCell, self).__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.hparams = hparams
        self._no_gate = self.hparams.get("no_gate", False)
        self._minimal = self.hparams.get("minimal", False)

        activation_name = self.hparams["backbone_activation"]
        if activation_name == "silu":
            backbone_activation = nn.SiLU
        elif activation_name == "relu":
            backbone_activation = nn.ReLU
        elif activation_name == "tanh":
            backbone_activation = nn.Tanh
        elif activation_name == "gelu":
            backbone_activation = nn.GELU
        elif activation_name == "lecun":
            backbone_activation = LeCun  # scaled-tanh module defined earlier in this file
        else:
            raise ValueError("Unknown activation")

        units = self.hparams["backbone_units"]
        layer_list = [nn.Linear(input_size + hidden_size, units), backbone_activation()]
        for _ in range(1, self.hparams["backbone_layers"]):
            layer_list.append(nn.Linear(units, units))
            layer_list.append(backbone_activation())
            # NOTE(review): dropout follows each additional hidden layer, as in the
            # upstream CfC cell (source indentation was lost) — confirm.
            if "backbone_dr" in self.hparams:
                layer_list.append(nn.Dropout(self.hparams["backbone_dr"]))
        self.backbone = nn.Sequential(*layer_list)

        self.tanh = nn.Tanh()
        self.sigmoid = nn.Sigmoid()
        self.ff1 = nn.Linear(units, hidden_size)
        if self._minimal:
            self.w_tau = nn.Parameter(torch.zeros(1, self.hidden_size))
            self.A = nn.Parameter(torch.ones(1, self.hidden_size))
        else:
            self.ff2 = nn.Linear(units, hidden_size)
            self.time_a = nn.Linear(units, hidden_size)
            self.time_b = nn.Linear(units, hidden_size)
        self.init_weights()

    def init_weights(self):
        """Xavier-initialise every 2-D parameter when ``hparams['init']`` (the gain) is set."""
        init_gain = self.hparams.get("init")
        if init_gain is not None:
            for w in self.parameters():
                if w.dim() == 2:
                    torch.nn.init.xavier_uniform_(w, gain=init_gain)

    def forward(self, input, hx, ts):
        """Advance the hidden state by one step.

        Args:
            input: Tensor concatenated with ``hx`` along dim 2. Presumably
                ``(batch, n, input_size)``; TODO confirm.
            hx: Previous hidden state, ``(batch, n, hidden_size)``.
            ts: Elapsed time, one value per batch element (viewed as ``(batch, 1, 1)``).

        Returns:
            The new hidden state, ``(batch, n, hidden_size)``.
        """
        batch_size = input.size(0)
        ts = ts.view(batch_size, 1).unsqueeze(1)  # (batch, 1, 1), broadcasts over n and hidden
        x = self.backbone(torch.cat([input, hx], 2))

        if self._minimal:
            # Closed-form solution; the warped time below is not needed here.
            ff1 = self.ff1(x)
            return -self.A * torch.exp(-ts * (torch.abs(self.w_tau) + torch.abs(ff1))) * ff1 + self.A

        # Warped time: exp(ts * (1 - 2 ln ts)) == exp(ts - 2 * ts * ln ts).
        # torch.xlogy(ts, ts) is 0 at ts == 0, giving the limit w = 1 instead of NaN.
        # NOTE(review): an earlier comment described exp(ts * (1 - ln ts)); the code uses
        # 2 * ln ts — confirm which is intended.
        w_ts = torch.exp(ts - 2 * torch.xlogy(ts, ts))
        ff1 = self.tanh(self.ff1(x))  # g
        ff2 = self.tanh(self.ff2(x))  # h
        t_interp = self.sigmoid(self.time_a(x) * w_ts + self.time_b(x))
        if self._no_gate:
            return ff1 + t_interp * ff2
        return ff1 * (1.0 - t_interp) + t_interp * ff2
class Cfc(nn.Module):
    """Period-encoded CfC forecaster: two recurrent passes plus a temporal projection.

    ``forward`` pipeline:
      1. Subtract each channel's temporal mean from the input.
      2. ``encoder`` (multi-period predictor) re-encodes the window. ``x_pro`` applies a
         second multi-period predictor to that encoding.
      3. Each sequence is cut into ``seq_len // module_len`` windows of ``module_len``
         steps. The shared recurrent cell unrolls over the ``module_len`` steps, with all
         windows processed in parallel. The first pass uses step 1; the second uses
         ``_SECOND_PASS_STEP``.
      4. With ``return_sequences``, all hidden states of both passes are concatenated
         along time (``2 * in_lens`` steps), projected to ``out_lens`` steps and mapped
         to ``out_feature`` channels by ``fc``. Otherwise ``fc`` is applied to the last
         hidden state of the first pass.
      5. The removed mean of the first ``out_feature`` input channels is added back.

    Args:
        in_features: Number of input channels.
        hidden_size: Hidden size of the recurrent cell.
        out_feature: Number of output channels.
        hparams: Dict forwarded to the cell, plus ``in_lens``, ``out_lens``,
            ``period_len`` (list of period lengths) and optional ``module_len``
            (default 24; it was hard-coded before, together with 7 windows).
        return_sequences: See step 4.
        use_mixed: Stored only; not used by this class.
        use_ltc: Use ``LTCCell`` instead of ``CfcCell``.
    """

    # Default window length the cell unrolls over.
    _DEFAULT_MODULE_LEN = 24
    # Constant elapsed time for the second (x_pro) pass. An earlier note listed
    # 168, 24 and 12 as candidate denominators.
    _SECOND_PASS_STEP = 1 / 168

    def __init__(
        self,
        in_features,
        hidden_size,
        out_feature,
        hparams,
        return_sequences=False,
        use_mixed=False,
        use_ltc=False,
    ):
        super(Cfc, self).__init__()
        self.in_features = in_features
        self.hidden_size = hidden_size
        self.out_feature = out_feature
        self.return_sequences = return_sequences
        self.in_lens = hparams['in_lens']
        self.out_lens = hparams['out_lens']
        self.module_len = hparams.get('module_len', self._DEFAULT_MODULE_LEN)
        self.period_lens = torch.tensor(hparams['period_len'], dtype=torch.int32)
        if use_ltc:
            # Stored under the name forward() uses. It was previously `self.rnn_cell`,
            # which made forward() raise AttributeError.
            # NOTE(review): LTCCell appears to expect 2-D (batch, units) states; the 3-D
            # (batch, windows, hidden) states used here are unverified with it.
            self.rnn_cell_forward = LTCCell(in_features, hidden_size)
        else:
            self.rnn_cell_forward = CfcCell(in_features, hidden_size, hparams)
        self.use_mixed = use_mixed
        self.fc = nn.Linear(self.hidden_size, self.out_feature)
        self.encoder = Multi_period_predication(self.in_lens, self.in_lens, self.in_features, self.period_lens)
        self.x_pro = Multi_period_predication(self.in_lens, self.in_lens, self.in_features, self.period_lens)
        self.Sequential_projection = Sequential_projection(self.in_lens * 2, self.out_lens)

    # Dev notes (2024-11-01 10:07): KANConv; parallel CfC; pre-decomposition is not advisable.

    @staticmethod
    def _to_windows(seq, module_num, module_len):
        """Reshape ``(b, n, t)`` into step-major windows ``(b, module_len, module_num, n)``."""
        batch_size, n_features, _ = seq.shape
        return seq.reshape(batch_size, n_features, module_num, module_len).permute(0, 3, 2, 1)

    def _unroll(self, windows, step):
        """Run the shared cell over the steps of ``windows`` with a constant elapsed time.

        Args:
            windows: ``(batch, module_len, module_num, n)`` from ``_to_windows``.
            step: Elapsed time passed to the cell at every step.

        Returns:
            ``(states, last_state)``. ``states`` is the stacked
            ``(batch, module_len, module_num, hidden)`` sequence when ``return_sequences``
            is set, else ``None``.
        """
        batch_size, module_len, module_num, _ = windows.shape
        device = windows.device
        ts = torch.full((batch_size,), step, device=device).squeeze()
        h_state = torch.zeros((batch_size, module_num, self.hidden_size), device=device)
        states = []
        for t in range(module_len):
            h_state = self.rnn_cell_forward(windows[:, t], h_state, ts)
            if self.return_sequences:
                states.append(h_state)
        stacked = torch.stack(states, dim=1) if self.return_sequences else None
        return stacked, h_state

    def forward(self, x, timespans=None, mask=None):
        """Forecast ``out_lens`` steps.

        Args:
            x: Input tensor that must be ``(batch, in_features, in_lens)`` after
                ``squeeze(1)``; presumably ``(batch, 1, in_features, in_lens)`` — TODO
                confirm.
            timespans: Ignored; constant step sizes are used.
            mask: Ignored.

        Returns:
            ``(batch, out_lens, out_feature)`` when ``return_sequences`` is set. Otherwise
            ``(batch, in_lens // module_len, out_feature)``, and the second pass is skipped
            because its result would be unused.

        Raises:
            ValueError: If the sequence length is not a multiple of ``module_len``.
        """
        x = x.squeeze(1).permute(0, 2, 1)  # (b, t, n)
        # Forward distribution transform: remove each channel's temporal mean.
        seq_mean = x.mean(dim=1, keepdim=True)  # (b, 1, n)
        x = x - seq_mean
        # Mean of the target channels (the first out_feature), added back at the end.
        seq_mean_out = seq_mean[:, :, :self.out_feature]

        # Cross-period encoding.
        encoded = self.encoder(x)  # (b, n, t)
        x_php = encoded.permute(0, 2, 1)  # (b, t, n)
        batch_size, seq_len, _ = x_php.shape
        module_len = self.module_len
        if seq_len % module_len != 0:
            raise ValueError(
                f"sequence length {seq_len} is not a multiple of module_len={module_len}"
            )
        module_num = seq_len // module_len

        seq_fwd, last_fwd = self._unroll(self._to_windows(encoded, module_num, module_len), 1.0)
        if not self.return_sequences:
            return self.fc(last_fwd) + seq_mean_out

        pro_windows = self._to_windows(self.x_pro(x_php), module_num, module_len)
        seq_pro, _ = self._unroll(pro_windows, self._SECOND_PASS_STEP)

        # (b, module_len, 2*module_num, H) -> (b, 2*module_num*module_len, H), window-major.
        readout = torch.cat((seq_fwd, seq_pro), dim=2)
        readout = readout.permute(0, 2, 1, 3).reshape(batch_size, module_num * module_len * 2, self.hidden_size)
        # Project the time axis 2*in_lens -> out_lens, then map hidden -> out_feature.
        readout = self.Sequential_projection(readout.permute(0, 2, 1)).permute(0, 2, 1)
        readout = self.fc(readout)
        # Inverse distribution transform.
        return readout + seq_mean_out
class LTCCell(nn.Module):
    """Liquid time-constant (LTC) cell.

    Each call applies ``ode_unfolds`` updates of the form
    ``v <- (cm_t * v + gleak * vleak + sum(w * sig(v) * erev) + sensory terms) /
    (cm_t + gleak + sum(w * sig(v)) + sensory terms + epsilon)``.

    Parameters are initialised uniformly from ``_init_ranges``. The reversal potentials
    ``erev`` / ``sensory_erev`` are random +1/-1 signs drawn from torch's RNG, so they
    are reproducible under ``torch.manual_seed``. ``softplus`` is an identity here, so
    positivity of ``w``, ``sensory_w``, ``cm`` and ``gleak`` is only enforced by
    ``apply_weight_constraints``.

    Args:
        in_features: Number of sensory inputs.
        units: Number of neurons (state size).
        ode_unfolds: Number of solver updates per call.
        epsilon: Added to the denominator to avoid division by zero.
    """

    def __init__(
        self,
        in_features,
        units,
        ode_unfolds=2,
        epsilon=1e-8,
    ):
        super(LTCCell, self).__init__()
        self.in_features = in_features
        self.units = units
        # (min, max) of the uniform initialisation for each parameter.
        self._init_ranges = {
            "gleak": (0.001, 1.0),
            "vleak": (-0.2, 0.2),
            "cm": (0.4, 0.6),
            "w": (0.001, 1.0),
            "sigma": (3, 8),
            "mu": (0.3, 0.8),
            "sensory_w": (0.001, 1.0),
            "sensory_sigma": (3, 8),
            "sensory_mu": (0.3, 0.8),
        }
        self._ode_unfolds = ode_unfolds
        self._epsilon = epsilon
        # Positivity transform is disabled (identity); nn.Softplus() was the alternative.
        self.softplus = nn.Identity()
        self._allocate_parameters()

    @property
    def state_size(self):
        return self.units

    @property
    def sensory_size(self):
        return self.in_features

    def add_weight(self, name, init_value):
        """Register ``init_value`` as a parameter called ``name`` and return it."""
        param = torch.nn.Parameter(init_value)
        self.register_parameter(name, param)
        return param

    def _get_init_value(self, shape, param_name):
        """Uniform sample in ``_init_ranges[param_name]`` (constant when min == max)."""
        minval, maxval = self._init_ranges[param_name]
        if minval == maxval:
            return torch.ones(shape) * minval
        else:
            return torch.rand(*shape) * (maxval - minval) + minval

    def _erev_initializer(self, shape=None):
        """Return random +1/-1 signs of ``shape`` as a default-dtype float tensor.

        Uses torch's global RNG. The previous ``np.random.default_rng()`` was seeded from
        OS entropy, so results were not reproducible.
        """
        size = () if shape is None else shape
        bits = torch.randint(0, 2, size).to(torch.get_default_dtype())
        return bits * 2 - 1

    def _allocate_parameters(self):
        """Create and register all parameters; ``self._params`` maps name -> Parameter."""
        self._params = {}
        state, sensory = self.state_size, self.sensory_size
        for name, shape in (
            ("gleak", (state,)),
            ("vleak", (state,)),
            ("cm", (state,)),
            ("sigma", (state, state)),
            ("mu", (state, state)),
            ("w", (state, state)),
        ):
            self._params[name] = self.add_weight(name=name, init_value=self._get_init_value(shape, name))
        self._params["erev"] = self.add_weight(
            name="erev", init_value=self._erev_initializer((state, state))
        )
        for name in ("sensory_sigma", "sensory_mu", "sensory_w"):
            self._params[name] = self.add_weight(
                name=name, init_value=self._get_init_value((sensory, state), name)
            )
        self._params["sensory_erev"] = self.add_weight(
            name="sensory_erev", init_value=self._erev_initializer((sensory, state))
        )
        self._params["input_w"] = self.add_weight(name="input_w", init_value=torch.ones((sensory,)))
        self._params["input_b"] = self.add_weight(name="input_b", init_value=torch.zeros((sensory,)))

    def _sigmoid(self, v_pre, mu, sigma):
        """Pairwise synaptic activation ``sigmoid(sigma * (v_pre[..., i, None] - mu))``."""
        v_pre = torch.unsqueeze(v_pre, -1)  # for broadcasting against (source, target)
        mues = v_pre - mu
        x = sigma * mues
        return torch.sigmoid(x)

    def _ode_solver(self, inputs, state, elapsed_time):
        """Apply ``ode_unfolds`` state updates.

        Raises:
            RuntimeError: If the state becomes NaN. This replaces a leftover
                ``breakpoint()``, which would stall or abort non-interactive runs.
        """
        v_pre = state
        # Sensory synapses depend only on the inputs, so compute them once.
        sensory_w_activation = self.softplus(self._params["sensory_w"]) * self._sigmoid(
            inputs, self._params["sensory_mu"], self._params["sensory_sigma"]
        )
        sensory_rev_activation = sensory_w_activation * self._params["sensory_erev"]
        # Reduce over dimension 1 (source sensory neurons).
        w_numerator_sensory = torch.sum(sensory_rev_activation, dim=1)
        w_denominator_sensory = torch.sum(sensory_w_activation, dim=1)
        # cm / dt is loop invariant; the +1 presumably avoids dividing by zero when
        # elapsed_time == 0.
        cm_t = self.softplus(self._params["cm"]).view(1, -1) / (
            (elapsed_time + 1) / self._ode_unfolds
        )
        gleak = self.softplus(self._params["gleak"])
        w = self.softplus(self._params["w"])
        for _ in range(self._ode_unfolds):
            w_activation = w * self._sigmoid(v_pre, self._params["mu"], self._params["sigma"])
            rev_activation = w_activation * self._params["erev"]
            # Reduce over dimension 1 (source neurons).
            w_numerator = torch.sum(rev_activation, dim=1) + w_numerator_sensory
            w_denominator = torch.sum(w_activation, dim=1) + w_denominator_sensory
            numerator = cm_t * v_pre + gleak * self._params["vleak"] + w_numerator
            denominator = cm_t + gleak + w_denominator
            v_pre = numerator / (denominator + self._epsilon)  # epsilon avoids / 0
            if torch.any(torch.isnan(v_pre)):
                raise RuntimeError("LTCCell ODE solver produced NaN state values")
        return v_pre

    def _map_inputs(self, inputs):
        """Element-wise affine map ``inputs * input_w + input_b``."""
        inputs = inputs * self._params["input_w"]
        inputs = inputs + self._params["input_b"]
        return inputs

    def _map_outputs(self, state):
        """Element-wise affine output map.

        NOTE(review): ``output_w`` / ``output_b`` are never allocated by
        ``_allocate_parameters``, so calling this raises KeyError. It is currently unused.
        """
        output = state
        output = output * self._params["output_w"]
        output = output + self._params["output_b"]
        return output

    def _clip(self, w):
        """Clamp negative values to zero."""
        return torch.relu(w)

    def apply_weight_constraints(self):
        """Clip ``w``, ``sensory_w``, ``cm`` and ``gleak`` to be non-negative, in place."""
        for name in ("w", "sensory_w", "cm", "gleak"):
            self._params[name].data = self._clip(self._params[name].data)

    def forward(self, input, hx, ts):
        """Advance state ``hx`` given ``input`` and per-sample elapsed time ``ts``.

        ``ts`` is viewed as ``(-1, 1)``; the returned value is the new state.
        """
        ts = ts.view((-1, 1))
        inputs = self._map_inputs(input)
        next_state = self._ode_solver(inputs, hx, ts)
        return next_state