import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from huggingface_hub import PyTorchModelHubMixin
class TimestepEmbedder(nn.Module):
"""Module to create timestep's embedding."""
def __init__(self,hidden_size,frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size,hidden_size),
nn.SiLU(),
nn.Linear(hidden_size,hidden_size)
)
self.frequency_embedding_size = frequency_embedding_size
    def forward(self, t):
        # Standard sinusoidal timestep embedding followed by a two-layer MLP.
        half = self.frequency_embedding_size // 2
        freqs = torch.exp(
            -math.log(10000) * torch.arange(start=0,end=half,dtype=torch.float32) / half
        ).to(device=t.device)
        args = torch.einsum('i,j->ij',t.float(),freqs) # (B,half); cast t so einsum dtypes match
        embedding = torch.cat([torch.cos(args),torch.sin(args)],dim=-1) # (B,frequency_embedding_size)
        mlp_dtype = next(self.mlp.parameters()).dtype
        return self.mlp(embedding.to(mlp_dtype))
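# Usage sketch (illustrative names, shapes only):
#   embedder = TimestepEmbedder(hidden_size=640)
#   t = torch.randint(0,1000,(8,)) # one integer timestep per sample
#   emb = embedder(t) # -> (8,640)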
class ViTAttn(nn.Module):
    """Thin self-attention wrapper around nn.MultiheadAttention."""
    def __init__(self,hidden_size,num_heads):
        super().__init__()
        self.attn = nn.MultiheadAttention(hidden_size,num_heads,bias=True,add_bias_kv=True,batch_first=True)
    def forward(self,x):
        attn, _ = self.attn(x,x,x) # self-attention: query = key = value = x
        return attn
class DiTBlock(nn.Module):
"""
DiT Block with adaptive layer norm zero (adaLN-Zero) conditioning.
Using post-norm
"""
def __init__(self,hidden_size,num_heads):
super().__init__()
self.norm1 = nn.LayerNorm(hidden_size,elementwise_affine=False,eps=1e-6)
self.attn = ViTAttn(hidden_size,num_heads)
self.norm2 = nn.LayerNorm(hidden_size,elementwise_affine=False,eps=1e-6)
self.mlp = nn.Sequential(
nn.Linear(hidden_size,4*hidden_size),
nn.GELU(approximate="tanh"),
nn.Linear(4*hidden_size,hidden_size)
)
self.adaLN = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size,6*hidden_size)
)
    def forward(self,x,c):
        # c: (B,hidden_size) conditioning vector -> six (B,hidden_size) modulation terms
        gamma_1,beta_1,alpha_1,gamma_2,beta_2,alpha_2 = self.adaLN(c).chunk(6,dim=-1)
        x = self.norm1(x + alpha_1.unsqueeze(1) * self.attn(x))
        x = x * (1+gamma_1.unsqueeze(1)) + beta_1.unsqueeze(1)
        x = self.norm2(x + alpha_2.unsqueeze(1) * self.mlp(x))
        x = x * (1+gamma_2.unsqueeze(1)) + beta_2.unsqueeze(1)
        return x
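# Shape sketch (illustrative, not part of the model): with tokens x of shape
# (B,N,hidden_size) and a conditioning vector c of shape (B,hidden_size),
#   block = DiTBlock(hidden_size=640,num_heads=10)
#   y = block(x,c) # -> (B,N,640)
# Because adaLN is zero-initialized (see initialize_weights below), alpha,
# gamma and beta all start at 0, so each block initially reduces to its two
# LayerNorms applied to x.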
class DiT(nn.Module,PyTorchModelHubMixin):
def __init__(self,
num_blocks=10,
hidden_size=640,
num_heads=10,
patch_size=2,
num_channels=4,
img_size=32,
num_genres=42,
num_styles=137):
super().__init__()
self.hidden_size = hidden_size
self.patch_size = patch_size
self.num_channels = num_channels
self.seq_len = (img_size // patch_size)**2
self.img_size = img_size
self.blocks = nn.ModuleList(
DiTBlock(hidden_size,num_heads) for _ in range(num_blocks)
)
self.timestep_embed = TimestepEmbedder(hidden_size)
self.num_genres = num_genres
self.num_styles = num_styles
self.genre_condition = nn.Embedding(num_genres+1,hidden_size) # +1 for null condition
self.style_condition = nn.Embedding(num_styles+1,hidden_size)
self.pos_embed = nn.Parameter(torch.zeros(1, self.seq_len, hidden_size))
patch_dim = num_channels * patch_size * patch_size
self.proj_in = nn.Linear(patch_dim,hidden_size)
self.proj_out = nn.Linear(hidden_size,patch_dim)
self.norm_out = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
self.adaLN_final = nn.Sequential(
nn.SiLU(),
nn.Linear(hidden_size, 2*hidden_size)
)
self.initialize_weights()
def initialize_weights(self):
nn.init.normal_(self.pos_embed, std=0.02)
nn.init.normal_(self.proj_out.weight, std=0.02)
nn.init.zeros_(self.proj_out.bias)
nn.init.normal_(self.proj_in.weight, std=0.02)
nn.init.zeros_(self.proj_in.bias)
nn.init.normal_(self.timestep_embed.mlp[0].weight, std=0.02)
nn.init.zeros_(self.timestep_embed.mlp[0].bias)
nn.init.normal_(self.timestep_embed.mlp[2].weight, std=0.02)
nn.init.zeros_(self.timestep_embed.mlp[2].bias)
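        # adaLN-Zero: zero-init every modulation head so each residual branch
        # starts disabled (alpha=0) and the norms start unmodulated (gamma=beta=0).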
for block in self.blocks:
nn.init.zeros_(block.adaLN[-1].weight)
nn.init.zeros_(block.adaLN[-1].bias)
nn.init.zeros_(self.adaLN_final[-1].weight)
nn.init.zeros_(self.adaLN_final[-1].bias)
nn.init.normal_(self.genre_condition.weight, std=0.02)
nn.init.normal_(self.style_condition.weight, std=0.02)
    def patchify(self,z):
        """
        (B,C,H,W) -> (B,(H/p)*(W/p),C*p*p) -> (B,seq_len,hidden_size)
        e.g. with C=6, p=2, H=W=32: (B,6,32,32) -> (B,256,24) -> (B,256,hidden_size)
        """
b,_,_,_ = z.shape
c = self.num_channels
p = self.patch_size
        z = z.unfold(2,p,p).unfold(3,p,p) # (b,c,h//p,w//p,p,p)
z = z.contiguous().view(b,c,-1,p,p) # (b,c,hw//p**2,p,p)
z = torch.einsum('bcapq->bacpq',z).contiguous().view(b,-1,c*p**2) # (b,hw//p**2,c*p**2)
return self.proj_in(z) # (b,hw//p**2,hidden_size)
    def unpatchify(self,z):
        """
        Inverse of patchify: (B,seq_len,hidden_size) -> (B,seq_len,C*p*p) -> (B,C,H,W)
        e.g. with C=6, p=2, H=W=32: (B,256,hidden_size) -> (B,256,24) -> (B,6,32,32)
        """
b,_,_ = z.shape
c = self.num_channels
p = self.patch_size
s = int(self.seq_len ** 0.5)
i = self.img_size
z = self.proj_out(z) # (b,hw//p**2,c*p**2)
z = z.view(b,s,s,c,p,p) # (b,h/p,w/p,c,p,p)
z = torch.einsum('befcpq->bcepfq',z) # (b,c,h/p,p,w/p,p)
z = z.contiguous().view(b,c,i,i)
return z
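    # Round-trip sketch (illustrative): unpatchify inverts patchify's spatial
    # rearrangement, up to the proj_in/proj_out linear layers:
    #   tokens = model.patchify(z)       # (B,seq_len,hidden_size)
    #   image = model.unpatchify(tokens) # (B,num_channels,img_size,img_size)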
    def forward(self,z,t,g,s):
        """
        z: (B,num_channels,img_size,img_size) latents
        t: (B,) integer timesteps
        g, s: (B,) integer genre / style indices (last index = null condition)
        """
        t_embed = self.timestep_embed(t) # (B,hidden_size)
        g_embed = self.genre_condition(g) # (B,hidden_size)
        s_embed = self.style_condition(s) # (B,hidden_size)
        c = t_embed + g_embed + s_embed
z = self.patchify(z)
z = z + self.pos_embed
for block in self.blocks:
z = block(z,c)
gamma, beta = self.adaLN_final(c).chunk(2,dim=-1)
z = self.norm_out(z)
z = z * (1+gamma.unsqueeze(1)) + beta.unsqueeze(1)
return self.unpatchify(z)
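    # Hypothetical classifier-free guidance helper -- not in the original file,
    # sketched here so the forward_cfg call in the smoke test below runs. It
    # assumes the last embedding index is the null condition (matching the
    # "+1 for null condition" comment above) and the usual CFG combination.
    def forward_cfg(self,z,t,g,s,cfg_scale=4.0):
        null_g = torch.full_like(g,self.num_genres)
        null_s = torch.full_like(s,self.num_styles)
        cond = self.forward(z,t,g,s)
        uncond = self.forward(z,t,null_g,null_s)
        return uncond + cfg_scale * (cond - uncond)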
if __name__ == "__main__":
    # Smoke test: 1 block, hidden size 768, 12 heads, patch 2, 6 channels, 32x32.
    model = DiT(1,768,12,2,6,32)
    z = torch.randn(2,6,32,32)
    t = torch.randint(0,1000,(2,))
    g = torch.randint(0,42,(2,)) # default num_genres=42 -> valid indices 0..42
    s = torch.randint(0,137,(2,)) # default num_styles=137 -> valid indices 0..137
    output = model(z,t,g,s)
    print(z.shape,t.shape,output.shape)
    output_cfg = model.forward_cfg(z,t,g,s)
    print(output_cfg.shape)