import math

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

class TimestepEmbedder(nn.Module):
    """Embeds scalar diffusion timesteps via sinusoidal features and an MLP."""
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    def forward(self, t):
        # Sinusoidal frequency bands, as in the standard Transformer/DiT embedding.
        half = self.frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(start=0, end=half) / half
        ).to(device=t.device)
        # Cast t to float so the outer product works for integer timesteps.
        args = torch.einsum('i,j->ij', t.float(), freqs)
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        # Match the MLP's parameter dtype (e.g. under mixed precision).
        mlp_dtype = next(self.mlp.parameters()).dtype
        return self.mlp(embedding.to(mlp_dtype))

class ViTAttn(nn.Module):
    """Multi-head self-attention over the patch-token sequence."""
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(
            hidden_size, num_heads, bias=True, add_bias_kv=True, batch_first=True
        )

    def forward(self, x):
        attn_out, _ = self.attn(x, x, x)
        return attn_out

class DiTBlock(nn.Module):
    """
    DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
    Uses a post-norm layout: each gated residual sum is normalized first,
    then scaled and shifted by the conditioning signal.
    """
    def __init__(self, hidden_size, num_heads):
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = ViTAttn(hidden_size, num_heads)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.GELU(approximate="tanh"),
            nn.Linear(4 * hidden_size, hidden_size),
        )
        # Regresses the six modulation parameters (scale, shift, gate per sublayer)
        # from the conditioning vector; zero-initialized in initialize_weights().
        self.adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size),
        )

    def forward(self, x, c):
        gamma_1, beta_1, alpha_1, gamma_2, beta_2, alpha_2 = self.adaLN(c).chunk(6, dim=1)
        # Attention sublayer: gated residual, then norm, then scale/shift.
        x = self.norm1(x + alpha_1.unsqueeze(1) * self.attn(x))
        x = x * (1 + gamma_1.unsqueeze(1)) + beta_1.unsqueeze(1)
        # MLP sublayer, same pattern.
        x = self.norm2(x + alpha_2.unsqueeze(1) * self.mlp(x))
        x = x * (1 + gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
        return x

class DiT(nn.Module, PyTorchModelHubMixin):
    def __init__(self,
                 num_blocks=10,
                 hidden_size=640,
                 num_heads=10,
                 patch_size=2,
                 num_channels=4,
                 img_size=32,
                 num_genres=42,
                 num_styles=137):
        super().__init__()
        self.hidden_size = hidden_size
        self.patch_size = patch_size
        self.num_channels = num_channels
        self.seq_len = (img_size // patch_size) ** 2
        self.img_size = img_size
        self.blocks = nn.ModuleList(
            DiTBlock(hidden_size, num_heads) for _ in range(num_blocks)
        )
        self.timestep_embed = TimestepEmbedder(hidden_size)
        self.num_genres = num_genres
        self.num_styles = num_styles
        self.genre_condition = nn.Embedding(num_genres + 1, hidden_size)  # +1 for null condition
        self.style_condition = nn.Embedding(num_styles + 1, hidden_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, hidden_size))
        patch_dim = num_channels * patch_size * patch_size
        self.proj_in = nn.Linear(patch_dim, hidden_size)
        self.proj_out = nn.Linear(hidden_size, patch_dim)
        self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.adaLN_final = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size),
        )
        self.initialize_weights()
    def initialize_weights(self):
        nn.init.normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.proj_out.weight, std=0.02)
        nn.init.zeros_(self.proj_out.bias)
        nn.init.normal_(self.proj_in.weight, std=0.02)
        nn.init.zeros_(self.proj_in.bias)
        nn.init.normal_(self.timestep_embed.mlp[0].weight, std=0.02)
        nn.init.zeros_(self.timestep_embed.mlp[0].bias)
        nn.init.normal_(self.timestep_embed.mlp[2].weight, std=0.02)
        nn.init.zeros_(self.timestep_embed.mlp[2].bias)
        # adaLN-Zero: zero-init every modulation head so all scales, shifts,
        # and gates start at zero.
        for block in self.blocks:
            nn.init.zeros_(block.adaLN[-1].weight)
            nn.init.zeros_(block.adaLN[-1].bias)
        nn.init.zeros_(self.adaLN_final[-1].weight)
        nn.init.zeros_(self.adaLN_final[-1].bias)
        nn.init.normal_(self.genre_condition.weight, std=0.02)
        nn.init.normal_(self.style_condition.weight, std=0.02)
    def patchify(self, z):
        """
        (b, c, h, w) -> (b, (h/p)*(w/p), c*p*p) -> (b, seq_len, hidden_size),
        e.g. (b, 6, 32, 32) -> (b, 256, 24) -> (b, 256, hidden_size) for p=2, c=6.
        """
        b, _, _, _ = z.shape
        c = self.num_channels
        p = self.patch_size
        z = z.unfold(2, p, p).unfold(3, p, p)  # (b, c, h/p, w/p, p, p)
        z = z.contiguous().view(b, c, -1, p, p)  # (b, c, hw/p^2, p, p)
        z = torch.einsum('bcapq->bacpq', z).contiguous().view(b, -1, c * p**2)  # (b, hw/p^2, c*p^2)
        return self.proj_in(z)  # (b, hw/p^2, hidden_size)
    def unpatchify(self, z):
        """
        (b, seq_len, hidden_size) -> (b, seq_len, c*p*p) -> (b, c, h, w),
        the inverse of patchify after projecting back to patch pixels.
        """
        b, _, _ = z.shape
        c = self.num_channels
        p = self.patch_size
        s = int(self.seq_len ** 0.5)
        i = self.img_size
        z = self.proj_out(z)  # (b, hw/p^2, c*p^2)
        z = z.view(b, s, s, c, p, p)  # (b, h/p, w/p, c, p, p)
        z = torch.einsum('befcpq->bcepfq', z)  # (b, c, h/p, p, w/p, p)
        z = z.contiguous().view(b, c, i, i)
        return z
    def forward(self, z, t, g, s):
        # Conditioning vector: timestep + genre + style embeddings.
        t_embed = self.timestep_embed(t)  # (batch_size, hidden_size)
        g_embed = self.genre_condition(g)
        s_embed = self.style_condition(s)
        c = t_embed + g_embed + s_embed
        z = self.patchify(z)
        z = z + self.pos_embed
        for block in self.blocks:
            z = block(z, c)
        # Final adaLN modulation before projecting back to patch pixels.
        gamma, beta = self.adaLN_final(c).chunk(2, dim=-1)
        z = self.norm_out(z)
        z = z * (1 + gamma.unsqueeze(1)) + beta.unsqueeze(1)
        return self.unpatchify(z)
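
    # NOTE: the test harness below calls forward_cfg, which the original file
    # never defined. This is a minimal classifier-free-guidance sketch under
    # the assumption that the extra embedding rows (index num_genres /
    # num_styles, the "+1 for null condition" slots above) are the null
    # tokens; the cfg_scale default is illustrative, not from the original.
    def forward_cfg(self, z, t, g, s, cfg_scale=4.0):
        # Unconditional pass: replace both labels with their null tokens.
        null_g = torch.full_like(g, self.num_genres)
        null_s = torch.full_like(s, self.num_styles)
        cond = self.forward(z, t, g, s)
        uncond = self.forward(z, t, null_g, null_s)
        # Guidance: push the prediction away from the unconditional output,
        # toward the conditional one.
        return uncond + cfg_scale * (cond - uncond)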

if __name__ == "__main__":
    model = DiT(num_blocks=1, hidden_size=768, num_heads=12, patch_size=2,
                num_channels=6, img_size=32)
    z = torch.randn(2, 6, 32, 32)
    t = torch.randint(0, 1000, (2,))
    g = torch.randint(0, model.num_genres, (2,))
    s = torch.randint(0, model.num_styles, (2,))
    output = model(z, t, g, s)
    print(z.shape, t.shape, g.shape, s.shape, output.shape)
    output_cfg = model.forward_cfg(z, t, g, s)
    print(output_cfg.shape)
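    # Extra sanity check (not in the original): patchify should produce
    # (batch, seq_len, hidden_size) tokens for this configuration.
    tokens = model.patchify(z)
    assert tokens.shape == (2, model.seq_len, model.hidden_size)
    print(tokens.shape)  # torch.Size([2, 256, 768])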