import torch
import torch.nn as nn
import math

from modules.audio_detokenizer.flow_matching.dit_block import DiTBlock, FinalLayer
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
                         interpolation_factor: int = 1, max_seq_length: int = 4096):
    print(f'using rope base theta = {theta}, interpolation factor = {interpolation_factor}')
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))

    # RoPE type-A extension:
    # we use interpolation rather than extrapolation for better position encoding.
    # For scaling purposes, t should be a float tensor.
    t = torch.arange(end, device=freqs.device).float()
    scale = 1.0 / float(interpolation_factor)
    t *= scale

    freqs = torch.outer(t, freqs).float()  # type: ignore
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

    # Sometimes we do not need that many RoPE embeddings because seq_len is smaller than
    # max_position_embeddings (e.g. RoPE precomputed up to 1M positions but seq_len is 32k),
    # which would waste GPU memory.
    if max_seq_length < end:
        freqs_cis = freqs_cis[:max_seq_length].clone()
    return freqs_cis
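
# Illustrative sketch (not part of the original module): one common way the complex
# frequencies returned above are applied to a query/key tensor. The actual rotation
# used by this model lives inside DiTBlock and may use a different layout.
def _example_apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
    # x: (seq_len, n_heads, head_dim); freqs_cis: (seq_len, head_dim // 2), complex64
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
    rotated = x_complex * freqs_cis[:, None, :]  # broadcast the rotation over heads
    return torch.view_as_real(rotated).flatten(-2).type_as(x)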
class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """
    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).float().to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
        return t_emb
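
# Illustrative sketch (not part of the original module): embedding a batch of
# diffusion timesteps; hidden_size=1024 is just an example value.
def _example_timestep_embedder():
    embedder = TimestepEmbedder(hidden_size=1024)
    t = torch.rand(8)      # 8 timesteps in [0, 1); fractional values are allowed
    t_emb = embedder(t)    # -> (8, 1024)
    return t_emb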
class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.
    Padding symbols are ignored.
    """
    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor(1))

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.
        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2  # d/2
        emb = math.log(10000) / (half_dim - 1)  # 2*log(10000)/(d-2)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)  # exp(-2i/(d-2)*log(10000)), i from 0 to (d-2)/2; shape: (d/2,)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0)  # pos / 10000 ** (2i/(d-2)); shape: (num_embeddings, d/2)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)  # shape: (num_embeddings, d)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None, timestep=None, **kwargs):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = input.shape[:2]
        max_pos = self.padding_idx + 1 + seq_len
        if self.weights is None or max_pos > self.weights.size(0):
            # recompute/expand embeddings if needed
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            )
        self.weights = self.weights.to(self._float_tensor)

        if incremental_state is not None:
            # positions are the same for every token when decoding a single step
            pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
            return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)

        positions = self.make_positions(input, self.padding_idx)
        return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach()  # (B, T, dim)

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number

    def make_positions(self, tensor, padding_idx):
        """Replace non-padding symbols with their position numbers.
        Position numbers begin at padding_idx+1. Padding symbols are ignored.
        """
        # The series of casts and type-conversions here are carefully
        # balanced to both work with ONNX export and XLA. In particular XLA
        # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
        # how to handle the dtype kwarg in cumsum.
        mask = tensor.ne(padding_idx).int()
        return (
            torch.cumsum(mask, dim=1).type_as(mask) * mask
        ).long() + padding_idx
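
# Illustrative sketch (not part of the original module): looking up positional
# embeddings for a batch of position ids, with 0 as the padding index just as
# DiTPrefix below configures it.
def _example_sincos_positions():
    pos_emb = SinusoidalPositionalEmbedding(embedding_dim=1024, padding_idx=0, init_size=4097)
    position_ids = torch.arange(1, 17).unsqueeze(0)  # (1, 16); real positions start after padding_idx
    return pos_emb(position_ids)                     # -> (1, 16, 1024)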
class DiTPrefix(nn.Module):
    """
    Diffusion model with a Transformer backbone.
    """
    def __init__(
        self,
        input_size,
        output_size,
        semantic_vocab_size,
        hidden_size=1024,
        depth=12,
        num_heads=4,
        # mlp related
        mlp_ratio=4.0,
        ffn_type="conv1d_conv1d",
        ffn_gated_glu=True,
        ffn_act_layer="gelu",
        ffn_conv_kernel_size=5,
        # rope
        use_rope=False,
        rope_params={
            "max_position_embeddings": 4096,
            "rope_base": 10000.0,
            "rope_interpolation_factor": 1.0,
        },
        position_embedding_type="sincos",
        max_seq_len=4096,
        prompt_cfg_dropout=0.0
    ):
        super().__init__()
        self.num_heads = num_heads
        self.prompt_cfg_dropout = prompt_cfg_dropout

        self.t_embedder = TimestepEmbedder(hidden_size)
        self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size)
        self.input_linear = nn.Linear(input_size, hidden_size)

        # position embedding
        if position_embedding_type == "learnable":
            self.position_embedding = nn.Embedding(max_seq_len + 1, hidden_size)
        elif position_embedding_type == "sincos":
            self.position_embedding = SinusoidalPositionalEmbedding(hidden_size, 0, max_seq_len + 1)
        elif position_embedding_type == "skip":
            self.position_embedding = None
        else:
            raise NotImplementedError("Position embedding type: {} not implemented.".format(position_embedding_type))

        self.use_rope = use_rope
        if self.use_rope:
            assert hidden_size % num_heads == 0, "Hidden size must be divisible by num_heads for rope position embedding."
            rope_dim = hidden_size // num_heads
            self.rotary_pos_emb = precompute_freqs_cis(
                rope_dim, rope_params["max_position_embeddings"],
                theta=rope_params["rope_base"],
                interpolation_factor=rope_params["rope_interpolation_factor"],
                max_seq_length=max_seq_len
            )

        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
                     ffn_type=ffn_type, ffn_conv_kernel_size=ffn_conv_kernel_size,
                     ffn_gated_glu=ffn_gated_glu, ffn_act_layer=ffn_act_layer)
            for _ in range(depth)
        ])
        self.final_layer = FinalLayer(hidden_size, output_size)

        self.initialize_weights()

    def initialize_weights(self):
        # Initialize transformer layers:
        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
        self.apply(_basic_init)

        # Initialize timestep embedding MLP:
        nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
        nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers in DiT blocks:
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers:
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)
    def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, cu_maxlen,
                cu_seqlens_k, cu_maxlen_k, mask, incremental_state=None, nopadding=True):
        """
        Forward pass of DiT.
        x: (N, T, C) tensor of inputs (latent representations of speech)
        position_ids: (N, T) tensor of positional indices
        t: (N,) tensor of diffusion timesteps
        condition: (N, T) tensor of semantic tokens
        seq_len: (N,) tensor of sequence lengths
        cu_seqlens / cu_maxlen, cu_seqlens_k / cu_maxlen_k, mask: attention bookkeeping,
            forwarded unchanged to each DiTBlock
        incremental_state: optional per-block cache for step-wise decoding
        """
        condition = self.semantic_token_embedding(condition)  # (N, T, D)
        x = self.input_linear(x)

        if self.position_embedding is not None:
            position_emb = self.position_embedding(position_ids)
            x = x + position_emb

        # RoPE
        if self.use_rope:
            if self.rotary_pos_emb.device != position_ids.device:
                self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device)
            # gather per-token rotary embeddings: (N, T, rope_dim // 2)
            rotary_pos_emb = self.rotary_pos_emb[position_ids]
        else:
            rotary_pos_emb = None

        t = self.t_embedder(t)  # (N, D)
        c = t.unsqueeze(1) + condition  # (N, T, D)

        for block_idx, block in enumerate(self.blocks):
            # XXX mask could be None because we always use full mask
            if incremental_state is not None:
                if block_idx not in incremental_state:
                    incremental_state[block_idx] = {}
                incr = incremental_state[block_idx]
            else:
                incr = None
            x = block(x=x, c=c, seq_len=seq_len, cu_seqlens=cu_seqlens, cu_maxlen=cu_maxlen,
                      cu_seqlens_k=cu_seqlens_k, cu_maxlen_k=cu_maxlen_k, mask=mask,
                      rotary_pos_emb=rotary_pos_emb, incremental_state=incr, nopadding=nopadding)  # (N, T, D)

        x = self.final_layer(x, c)  # (N, T, C)
        return x
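
# Illustrative sketch (not part of the original module): building the kinds of tensors
# DiTPrefix.forward expects for a padded batch. The cu_seqlens construction follows the
# usual flash-attention varlen convention (cumulative lengths prefixed with 0); whether
# that matches DiTBlock's exact expectations depends on its implementation.
def _example_build_inputs(batch_size=2, max_len=16, input_size=80, semantic_vocab_size=1000):
    seq_len = torch.tensor([16, 12])                                   # (N,) true lengths
    x = torch.randn(batch_size, max_len, input_size)                   # (N, T, C) noisy latents
    condition = torch.randint(0, semantic_vocab_size, (batch_size, max_len))  # (N, T) semantic tokens
    t = torch.rand(batch_size)                                         # (N,) diffusion timesteps

    # positions start at 1; padded slots use the padding_idx 0 of the sincos embedding
    pad_mask = torch.arange(max_len).unsqueeze(0) >= seq_len.unsqueeze(1)
    position_ids = torch.arange(1, max_len + 1).unsqueeze(0).repeat(batch_size, 1)
    position_ids[pad_mask] = 0

    cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int32),
                            seq_len.cumsum(0).to(torch.int32)])        # e.g. [0, 16, 28]
    cu_maxlen = int(seq_len.max())
    return x, position_ids, t, condition, seq_len, cu_seqlens, cu_maxlen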