from typing import List

import torch
from torch import nn

from diffusers.models.embeddings import Timesteps, TimestepEmbedding


def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
    """Compute rotary positional embedding (RoPE) tables for the given positions.

    Returns a tensor of shape (batch, seq_len, dim // 2, 2, 2) holding the 2x2
    rotation matrices [[cos, -sin], [sin, cos]] for each position and frequency.
    """
    assert dim % 2 == 0, "The dimension must be even."

    # Frequencies omega_k = 1 / theta^(2k / dim) for k = 0 .. dim/2 - 1.
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)

    # Angle for every (position, frequency) pair.
    batch_size, seq_length = pos.shape
    out = torch.einsum("...n,d->...nd", pos, omega)
    cos_out = torch.cos(out)
    sin_out = torch.sin(out)

    # Pack the angles into 2x2 rotation matrices.
    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
    return out.float()


class EmbedND(nn.Module):
    """Multi-axis rotary embedding: one RoPE table per positional axis, concatenated."""

    def __init__(self, theta: int, axes_dim: List[int]):
        super().__init__()
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: torch.Tensor) -> torch.Tensor:
        # ids: (batch, seq_len, n_axes) coordinates, one column per positional axis.
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        # Insert a singleton head axis: (batch, seq_len, 1, sum(axes_dim) // 2, 2, 2).
        return emb.unsqueeze(2)
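

# The (batch, seq_len, 1, dim // 2, 2, 2) table produced by EmbedND is meant to be
# consumed inside attention. The helper below is an illustrative sketch, not part of
# the original module: it assumes query/key tensors laid out as
# (batch, seq_len, num_heads, head_dim), which is the layout the singleton head axis
# above broadcasts against.
def apply_rope_example(xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor):
    # View the last dimension as head_dim // 2 pairs, each to be rotated.
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    # Rotate each pair by its precomputed 2x2 rotation matrix.
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)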


class PatchEmbed(nn.Module):
    """Linear projection of flattened latent patches into the transformer hidden size."""

    def __init__(
        self,
        patch_size=2,
        in_channels=4,
        out_channels=1024,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.out_channels = out_channels
        self.proj = nn.Linear(in_channels * patch_size * patch_size, out_channels, bias=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, latent):
        # latent: (batch, num_patches, in_channels * patch_size * patch_size), already patchified.
        latent = self.proj(latent)
        return latent
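

# PatchEmbed.proj above expects latents already flattened into
# (batch, num_patches, in_channels * patch_size * patch_size). The helper below is an
# illustrative sketch (not part of the original module) of one way to produce that
# layout from a (batch, channels, height, width) latent; the exact channel/pixel
# ordering used by the surrounding model is an assumption here.
def patchify_example(latent: torch.Tensor, patch_size: int) -> torch.Tensor:
    b, c, h, w = latent.shape
    # Split height and width into non-overlapping patch_size x patch_size blocks.
    latent = latent.view(b, c, h // patch_size, patch_size, w // patch_size, patch_size)
    # (b, h/p, w/p, c, p, p): group each patch's channels and pixels together.
    latent = latent.permute(0, 2, 4, 1, 3, 5)
    return latent.reshape(b, -1, c * patch_size * patch_size)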


class PooledEmbed(nn.Module):
    """Project the pooled text embedding into the transformer hidden size via an MLP."""

    def __init__(self, text_emb_dim, hidden_size):
        super().__init__()
        self.pooled_embedder = TimestepEmbedding(in_channels=text_emb_dim, time_embed_dim=hidden_size)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, pooled_embed):
        return self.pooled_embedder(pooled_embed)


class TimestepEmbed(nn.Module):
    """Sinusoidal timestep features followed by an MLP projection to the hidden size."""

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.time_proj = Timesteps(num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            nn.init.normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, timesteps, wdtype):
        # Sinusoidal features are computed in float32, then cast to the model weight dtype.
        t_emb = self.time_proj(timesteps).to(dtype=wdtype)
        t_emb = self.timestep_embedder(t_emb)
        return t_emb


class OutEmbed(nn.Module):
    """Final adaLN-modulated layer norm and linear projection back to patch pixels."""

    def __init__(self, hidden_size, patch_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Zero-initialize the projection and modulation so the block starts with zero output.
        if isinstance(m, nn.Linear):
            nn.init.zeros_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x, adaln_input):
        # Predict per-sample shift/scale from the conditioning vector, then modulate.
        shift, scale = self.adaLN_modulation(adaln_input).chunk(2, dim=1)
        x = self.norm_final(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = self.linear(x)
        return x
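

# Minimal shape-check sketch. The sizes below are arbitrary assumptions chosen only to
# exercise the modules; they are not the model's real configuration.
if __name__ == "__main__":
    hidden_size = 64

    patch_embed = PatchEmbed(patch_size=2, in_channels=4, out_channels=hidden_size)
    latent = torch.randn(1, 16, 4 * 2 * 2)                      # (batch, num_patches, C * p * p)
    print(patch_embed(latent).shape)                            # torch.Size([1, 16, 64])

    rope_embed = EmbedND(theta=10000, axes_dim=[16, 16])
    ids = torch.zeros(1, 16, 2)                                 # two positional axes per token
    print(rope_embed(ids).shape)                                # torch.Size([1, 16, 1, 16, 2, 2])

    t_embed = TimestepEmbed(hidden_size)
    print(t_embed(torch.tensor([0.5]), torch.float32).shape)    # torch.Size([1, 64])

    out_embed = OutEmbed(hidden_size, patch_size=2, out_channels=4)
    x = torch.randn(1, 16, hidden_size)
    adaln = torch.randn(1, hidden_size)
    print(out_embed(x, adaln).shape)                            # torch.Size([1, 16, 16])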