Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
def zero_module(module):
    """Set every parameter of ``module`` to zero in place and hand the module back.

    Args:
        module: Object exposing ``parameters()`` (e.g. an ``nn.Module``); all
            tensors yielded by that call are overwritten with zeros.

    Returns:
        The same ``module`` object that was passed in, so the call can be
        nested inside a constructor expression.
    """
    for weight in module.parameters():
        # Zero through a detached alias so the in-place write is not
        # recorded by autograd.
        detached = weight.detach()
        detached.zero_()
    return module
class StylizationBlock(nn.Module):
    """Modulate normalized features with a scale and shift derived from an embedding.

    The embedding is projected to ``2 * latent_dim`` values which are split
    into a multiplicative scale and an additive shift. These are applied to
    the layer-normalized features, and the result goes through
    SiLU -> Dropout -> Linear.

    Args:
        latent_dim: Size of the last dimension of the features ``h``.
        time_embed_dim: Size of the last dimension of the embedding ``emb``.
        dropout: Dropout probability used in the output stack.
    """

    def __init__(self, latent_dim, time_embed_dim, dropout):
        super().__init__()
        # One projection produces both halves (scale and shift) at once.
        emb_projection = nn.Linear(time_embed_dim, 2 * latent_dim)
        self.emb_layers = nn.Sequential(nn.SiLU(), emb_projection)

        self.norm = nn.LayerNorm(latent_dim)

        # The final projection has all its parameters zeroed at construction.
        out_projection = zero_module(nn.Linear(latent_dim, latent_dim))
        self.out_layers = nn.Sequential(
            nn.SiLU(),
            nn.Dropout(p=dropout),
            out_projection,
        )

    def forward(self, h, emb):
        """Apply embedding-conditioned scale/shift to ``h``.

        Args:
            h: Features, documented by the original author as (B, T, D).
            emb: Embedding whose last dimension is ``time_embed_dim``;
                expected to be 2-D (B, time_embed_dim) so that it becomes
                (B, 1, 2D) after projection and ``unsqueeze(1)``.

        Returns:
            Tensor produced by the output stack, same shape as ``h``.
        """
        # (B, 1, 2D): the singleton axis broadcasts over the T axis of h.
        modulation = self.emb_layers(emb).unsqueeze(1)
        # Two (B, 1, D) halves: first is the scale, second is the shift.
        scale, shift = modulation.chunk(2, dim=2)
        normalized = self.norm(h)
        h = normalized * (1 + scale) + shift
        return self.out_layers(h)