import torch
import torch.nn as nn
import math
from modules.audio_detokenizer.flow_matching.dit_block import DiTBlock, FinalLayer
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0,
interpolation_factor: int = 1, max_seq_length: int = 4096):
print(f'using rope base theta = {theta}, interpolation factor = {interpolation_factor}')
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    # RoPE type-A extension:
    # we use position interpolation rather than extrapolation for better position encoding
# for scale purposes, t should be a float tensor
t = torch.arange(end, device=freqs.device).float()
scale = 1.0 / float(interpolation_factor)
t *= scale
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
    # When seq_len is smaller than the precomputed range (e.g. a 1M-position RoPE
    # table but 32k sequences), keep only the needed prefix to avoid wasting GPU memory.
    if max_seq_length < end:
        freqs_cis = freqs_cis[:max_seq_length].clone()
return freqs_cis
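
# Usage sketch (hypothetical sizes, not taken from this repo's configs): a 16-head
# model with hidden_size=1024 has a per-head dim of 64, so the call below yields a
# complex64 table of shape (4096, 32); interpolation_factor=2 compresses positions
# 0..8191 into the 0..4096 range before truncation to max_seq_length:
#
#   freqs_cis = precompute_freqs_cis(64, 8192, theta=10000.0,
#                                    interpolation_factor=2, max_seq_length=4096)
#   assert freqs_cis.shape == (4096, 32) and freqs_cis.dtype == torch.complex64
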
class TimestepEmbedder(nn.Module):
"""
Embeds scalar timesteps into vector representations.
"""
def __init__(self, hidden_size, frequency_embedding_size=256):
super().__init__()
self.mlp = nn.Sequential(
nn.Linear(frequency_embedding_size, hidden_size, bias=True),
nn.SiLU(),
nn.Linear(hidden_size, hidden_size, bias=True),
)
self.frequency_embedding_size = frequency_embedding_size
@staticmethod
def timestep_embedding(t, dim, max_period=10000):
"""
Create sinusoidal timestep embeddings.
:param t: a 1-D Tensor of N indices, one per batch element.
These may be fractional.
:param dim: the dimension of the output.
:param max_period: controls the minimum frequency of the embeddings.
:return: an (N, D) Tensor of positional embeddings.
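        Example (shapes only; a minimal sanity check):
            >>> t = torch.tensor([0.0, 0.5, 1.0])
            >>> TimestepEmbedder.timestep_embedding(t, 256).shape
            torch.Size([3, 256])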
"""
# https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
half = dim // 2
freqs = torch.exp(
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
).float().to(device=t.device)
args = t[:, None].float() * freqs[None]
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
if dim % 2:
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
return embedding
def forward(self, t):
t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
return t_emb
class SinusoidalPositionalEmbedding(nn.Module):
"""This module produces sinusoidal positional embeddings of any length.
Padding symbols are ignored.
"""
def __init__(self, embedding_dim, padding_idx, init_size=1024):
super().__init__()
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.weights = SinusoidalPositionalEmbedding.get_embedding(
init_size,
embedding_dim,
padding_idx,
)
self.register_buffer('_float_tensor', torch.FloatTensor(1))
@staticmethod
def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
"""Build sinusoidal embeddings.
This matches the implementation in tensor2tensor, but differs slightly
from the description in Section 3.5 of "Attention Is All You Need".
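        Example (hypothetical sizes): an 8-position, 6-dim table whose padding
        row is zeroed out:
            >>> emb = SinusoidalPositionalEmbedding.get_embedding(8, 6, padding_idx=0)
            >>> emb.shape
            torch.Size([8, 6])
            >>> emb[0].abs().sum().item()
            0.0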
"""
half_dim = embedding_dim // 2 # d/2
emb = math.log(10000) / (half_dim - 1) # 2*log(10000)/(d-2)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb) # exp(-2i/(d-2)*log(10000)); i from 0 to (d-2)/2; shape: (d/2,)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * emb.unsqueeze(0) # pos / 10000 ** (2i/(d-2)); shape: (num_embeddings, d/2)
emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1) # shape: (num_embeddings, d)
if embedding_dim % 2 == 1:
# zero pad
emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
if padding_idx is not None:
emb[padding_idx, :] = 0
return emb
def forward(self, input, incremental_state=None, timestep=None, **kwargs):
"""Input is expected to be of size [bsz x seqlen]."""
bsz, seq_len = input.shape[:2]
max_pos = self.padding_idx + 1 + seq_len
if self.weights is None or max_pos > self.weights.size(0):
# recompute/expand embeddings if needed
self.weights = SinusoidalPositionalEmbedding.get_embedding(
max_pos,
self.embedding_dim,
self.padding_idx,
)
self.weights = self.weights.to(self._float_tensor)
if incremental_state is not None:
# positions is the same for every token when decoding a single step
pos = timestep.view(-1)[0] + 1 if timestep is not None else seq_len
return self.weights[self.padding_idx + pos, :].expand(bsz, 1, -1)
positions = self.make_positions(input, self.padding_idx)
return self.weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1).detach() # (B, T, dim)
def max_positions(self):
"""Maximum number of supported positions."""
return int(1e5) # an arbitrary large number
def make_positions(self, tensor, padding_idx):
"""Replace non-padding symbols with their position numbers.
Position numbers begin at padding_idx+1. Padding symbols are ignored.
"""
# The series of casts and type-conversions here are carefully
# balanced to both work with ONNX export and XLA. In particular XLA
# prefers ints, cumsum defaults to output longs, and ONNX doesn't know
# how to handle the dtype kwarg in cumsum.
mask = tensor.ne(padding_idx).int()
return (
torch.cumsum(mask, dim=1).type_as(mask) * mask
).long() + padding_idx
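
# make_positions sketch (hypothetical tokens, padding_idx=1): non-padding symbols
# are numbered from padding_idx + 1 onward, while padding slots keep padding_idx:
#
#   tokens = torch.tensor([[5, 6, 1],
#                          [7, 1, 1]])
#   mask   = [[1, 1, 0], [1, 0, 0]]   # tokens.ne(1)
#   out    = [[2, 3, 1], [2, 1, 1]]   # cumsum(mask) * mask + padding_idx
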
class DiTPrefix(nn.Module):
"""
Diffusion model with a Transformer backbone.
"""
def __init__(
self,
input_size,
output_size,
semantic_vocab_size,
hidden_size=1024,
depth=12,
num_heads=4,
# mlp related
mlp_ratio=4.0,
ffn_type="conv1d_conv1d",
ffn_gated_glu=True,
ffn_act_layer="gelu",
ffn_conv_kernel_size=5,
# rope
use_rope=False,
rope_params={
"max_position_embeddings": 4096,
"rope_base": 10000.0,
"rope_interpolation_factor": 1.0,
},
position_embedding_type="sincos",
max_seq_len=4096,
prompt_cfg_dropout=0.0
):
super().__init__()
self.num_heads = num_heads
self.prompt_cfg_dropout = prompt_cfg_dropout
self.t_embedder = TimestepEmbedder(hidden_size)
self.semantic_token_embedding = nn.Embedding(semantic_vocab_size, hidden_size)
self.input_linear = nn.Linear(input_size, hidden_size)
# position embedding
if position_embedding_type == "learnable":
self.position_embedding = nn.Embedding(max_seq_len+1, hidden_size)
elif position_embedding_type == "sincos":
self.position_embedding = SinusoidalPositionalEmbedding(hidden_size, 0, max_seq_len+1)
elif position_embedding_type == "skip":
self.position_embedding = None
else:
raise NotImplementedError("Position embedding type: {} not implemented.".format(position_embedding_type))
self.use_rope = use_rope
if self.use_rope:
assert hidden_size % num_heads == 0, "Hidden size must be divisible by num_heads for rope position embedding."
rope_dim = hidden_size // num_heads
self.rotary_pos_emb = precompute_freqs_cis(
rope_dim, rope_params["max_position_embeddings"],
theta=rope_params["rope_base"],
interpolation_factor=rope_params["rope_interpolation_factor"],
max_seq_length=max_seq_len
)
self.blocks = nn.ModuleList([
DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio,
ffn_type=ffn_type, ffn_conv_kernel_size=ffn_conv_kernel_size, ffn_gated_glu=ffn_gated_glu, ffn_act_layer=ffn_act_layer) for _ in range(depth)
])
self.final_layer = FinalLayer(hidden_size, output_size)
self.initialize_weights()
def initialize_weights(self):
# Initialize transformer layers:
def _basic_init(module):
if isinstance(module, nn.Linear):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
# Initialize timestep embedding MLP:
nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
# Zero-out adaLN modulation layers in DiT blocks:
for block in self.blocks:
nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
# Zero-out output layers:
nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
nn.init.constant_(self.final_layer.linear.weight, 0)
nn.init.constant_(self.final_layer.linear.bias, 0)
def forward(self, x, position_ids, t, condition, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, incremental_state=None, nopadding=True):
"""
Forward pass of DiT.
x: (N, T, C) tensor of inputs (latent representations of speech)
position_ids: (N, T) tensor of positional indices
t: (N,) tensor of diffusion timesteps
condition: (N, T) tensor of semantic tokens
seq_len: (N,) tensor of sequence lengths
"""
condition = self.semantic_token_embedding(condition) # (N, T, D)
x = self.input_linear(x)
if self.position_embedding is not None:
position_emb = self.position_embedding(position_ids)
x = x + position_emb
        # RoPE
        if self.use_rope:
            if self.rotary_pos_emb.device != position_ids.device:
                self.rotary_pos_emb = self.rotary_pos_emb.to(position_ids.device)
            # advanced indexing gathers an (N, T, head_dim // 2) complex tensor in
            # one shot, equivalent to filling each batch row in a Python loop
            rotary_pos_emb = self.rotary_pos_emb[position_ids]
else:
rotary_pos_emb = None
t = self.t_embedder(t) # (N, D)
c = t.unsqueeze(1) + condition # (N, T, D)
for block_idx, block in enumerate(self.blocks):
            # NOTE: mask may be None here because a full attention mask is always used
if incremental_state is not None:
if block_idx not in incremental_state:
incremental_state[block_idx] = {}
incr = incremental_state[block_idx]
else:
incr = None
x = block(x=x, c=c, seq_len=seq_len, cu_seqlens=cu_seqlens, cu_maxlen=cu_maxlen, cu_seqlens_k=cu_seqlens_k, cu_maxlen_k=cu_maxlen_k, mask=mask, rotary_pos_emb=rotary_pos_emb, incremental_state=incr, nopadding=nopadding)
x = self.final_layer(x, c) # (N, T, C)
return x
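
# Construction sketch (hypothetical sizes; not taken from this repo's configs).
# The forward pass additionally expects flash-attention style cu_seqlens / cu_maxlen
# bookkeeping from the caller, so only instantiation is shown:
#
#   model = DiTPrefix(input_size=80, output_size=80, semantic_vocab_size=8192,
#                     hidden_size=1024, depth=12, num_heads=16,
#                     use_rope=True, position_embedding_type="skip")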