from typing import Any

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import RMSNorm

from .config import DiaConfig


def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
    return tuple(ax if ax >= 0 else ndim + ax for ax in axes)

def _str_to_dtype(dtype_str: str | None) -> torch.dtype | None:
    # Allow None (or the string "none") for default behavior.
    if dtype_str is None or dtype_str.lower() == "none":
        return None
    if dtype_str == "float32":
        return torch.float32
    elif dtype_str == "float16":
        return torch.float16
    elif dtype_str == "bfloat16":
        return torch.bfloat16
    else:
        raise ValueError(f"Unsupported dtype string: {dtype_str}")

class DenseGeneral(nn.Module):
    """
    PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.

    Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
    for the generalized matrix multiplication. Weight shapes are calculated and
    parameters created during initialization based on config. The weight is
    allocated uninitialized (`torch.empty`); checkpoint loading is expected to
    copy data into it.

    Attributes:
        axis (tuple[int, ...]): Input axis or axes to contract.
        in_shapes (tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
        out_features (tuple[int, ...]): Shape of the output features (non-contracted dims).
        weight (nn.Parameter): The kernel parameter.
        bias: Registered as None; no bias term is used.
    """
    def __init__(
        self,
        in_shapes: tuple[int, ...],
        out_features: tuple[int, ...],
        axis: tuple[int, ...] = (-1,),
        dtype: torch.dtype | None = None,
        weight_dtype: torch.dtype | None = None,
        device: torch.device | None = None,
    ):
        super().__init__()
        self.in_shapes = in_shapes
        self.out_features = out_features
        self.axis = axis
        self.dtype = dtype
        self.kernel_shape = self.in_shapes + self.out_features

        factory_kwargs = {"device": device, "dtype": weight_dtype}
        self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
        self.register_parameter("bias", None)

    def forward(self, inputs: Tensor) -> Tensor:
        norm_axis = _normalize_axes(self.axis, inputs.ndim)
        kernel_contract_axes = tuple(range(len(norm_axis)))

        output = torch.tensordot(
            inputs.float(),
            self.weight.float(),
            dims=(norm_axis, kernel_contract_axes),
        ).to(inputs.dtype)
        return output

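# Illustrative sketch (not part of the model): how DenseGeneral maps shapes.
# The sizes below are made up for demonstration; real sizes come from DiaConfig.
def _example_dense_general() -> None:
    proj = DenseGeneral(in_shapes=(512,), out_features=(8, 64), axis=(-1,))
    nn.init.normal_(proj.weight, std=0.02)  # the weight starts uninitialized (torch.empty)
    x = torch.randn(2, 10, 512)             # (B, T, D)
    y = proj(x)                             # contracts D against the (512, 8, 64) kernel
    assert y.shape == (2, 10, 8, 64)        # (B, T, num_heads, head_dim)
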
def get_activation_fn(activation_string: str) -> nn.Module:  # Return Module instance
    """Maps activation string to PyTorch activation function module."""
    if activation_string == "gelu":
        return nn.GELU()
    elif activation_string == "relu":
        return nn.ReLU()
    elif activation_string == "silu" or activation_string == "swish":
        return nn.SiLU()
    elif activation_string == "linear":
        return nn.Identity()
    else:
        raise ValueError(f"Unsupported activation function: {activation_string}")

class MlpBlock(nn.Module):
    """MLP block using DenseGeneral."""

    def __init__(
        self,
        config: DiaConfig,
        embed_dim: int,
        intermediate_dim: int,
        dropout_rate: float,
        activations: list[str] = ["silu", "linear"],
        use_pre_norm: bool = False,
    ):
        super().__init__()
        self.use_pre_norm = use_pre_norm
        num_activations = len(activations)
        compute_dtype = _str_to_dtype(config.training.dtype)
        weight_dtype = _str_to_dtype(config.model.weight_dtype)
        self.dtype = compute_dtype
        # Assume default device for now, could be passed in config

        if use_pre_norm:
            self.pre_norm = RMSNorm(
                embed_dim,
                eps=config.model.normalization_layer_epsilon,
                dtype=torch.float32,
            )

        self.wi_fused = DenseGeneral(
            in_shapes=(embed_dim,),
            out_features=(
                num_activations,
                intermediate_dim,
            ),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )

        self.activation_fn_0 = get_activation_fn(activations[0])  # silu
        self.activation_fn_1 = get_activation_fn(activations[1])  # linear

        self.dropout = nn.Dropout(dropout_rate)

        # Output layer using DenseGeneral
        self.wo = DenseGeneral(
            in_shapes=(intermediate_dim,),
            out_features=(embed_dim,),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )

    def forward(self, x: torch.Tensor, deterministic: bool) -> torch.Tensor:
        """Forward pass."""
        if self.use_pre_norm and hasattr(self, "pre_norm"):
            x = self.pre_norm(x)

        fused_x = self.wi_fused(x)

        gate_input = fused_x[..., 0, :]
        up_input = fused_x[..., 1, :]

        gate = self.activation_fn_0(gate_input)
        up = self.activation_fn_1(up_input)
        hidden = torch.mul(gate, up).to(self.dtype)

        if not deterministic:
            hidden = self.dropout(hidden)

        output = self.wo(hidden)
        return output

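# Illustrative sketch (not part of the model): the fused gate/up computation in
# MlpBlock, spelled out with plain tensors. wi_fused packs both projections into one
# kernel of shape (D, 2, F); index 0 is the gated branch (silu), index 1 the linear
# branch, and the two are multiplied before wo projects back to D. Sizes are made up.
def _example_gated_mlp_math() -> None:
    B, T, D, F_dim = 2, 4, 16, 32
    x = torch.randn(B, T, D)
    w_fused = torch.randn(D, 2, F_dim)                    # mirrors wi_fused's kernel layout
    fused = torch.tensordot(x, w_fused, dims=([2], [0]))  # (B, T, 2, F)
    gate = F.silu(fused[..., 0, :])                       # activations[0] == "silu"
    up = fused[..., 1, :]                                 # activations[1] == "linear"
    hidden = gate * up                                    # (B, T, F)
    assert hidden.shape == (B, T, F_dim)
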
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) implementation in PyTorch."""

    def __init__(
        self,
        embedding_dims: int,
        min_timescale: int = 1,
        max_timescale: int = 10000,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        if embedding_dims % 2 != 0:
            raise ValueError("Embedding dim must be even for RoPE.")

        self.embedding_dims = embedding_dims
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale
        self.dtype = dtype

        half_embedding_dim = embedding_dims // 2
        fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
        self.register_buffer(
            "timescale",
            self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction,
            persistent=False,
        )

    def extra_repr(self) -> str:
        s = f"{self.timescale.shape}"
        return s

    def forward(self, inputs: torch.Tensor, position: torch.Tensor):
        """Applies RoPE."""
        position = position.unsqueeze(-1).unsqueeze(-1)
        timescale = self.timescale.to(inputs.device)
        sinusoid_inp = position / timescale
        sin = torch.sin(sinusoid_inp).to(inputs.dtype)
        cos = torch.cos(sinusoid_inp).to(inputs.dtype)
        first_half, second_half = torch.chunk(inputs, 2, dim=-1)
        first_part = first_half * cos - second_half * sin
        second_part = second_half * cos + first_half * sin
        return torch.cat((first_part, second_part), dim=-1)

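# Illustrative sketch (not part of the model): RoPE applied to a (B, T, heads, head_dim)
# tensor. The shape is preserved, and position 0 is left unchanged (sin(0)=0, cos(0)=1).
# All sizes below are made up.
def _example_rope() -> None:
    rope = RotaryEmbedding(embedding_dims=64)
    x = torch.randn(2, 10, 8, 64)                       # (B, T, N, H); H must be even
    pos = torch.arange(10).unsqueeze(0).expand(2, 10)   # (B, T) absolute positions
    y = rope(x, position=pos)
    assert y.shape == x.shape
    assert torch.allclose(y[:, 0], x[:, 0], atol=1e-6)  # identity rotation at position 0
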
class KVCache:
    """Simple per-layer key/value cache for autoregressive decoding.

    Tensors are preallocated with shape (2, num_heads, max_len, head_dim); the leading
    dimension of 2 matches the fixed (2, ...) batch used throughout the decoder (see the
    shape comments in DecoderLayer.forward). `current_idx` tracks how many positions
    have been written so far.
    """

    def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
        self.k = torch.zeros((2, num_heads, max_len, head_dim), device=device) if k is None else k
        self.v = torch.zeros((2, num_heads, max_len, head_dim), device=device) if v is None else v
        self.current_idx = 0
        self.max_len = max_len

    def get_kv_for_attention(self, current_k, current_v):
        # Return the cached prefix concatenated with the current step's K/V
        # (without writing the current step into the cache yet).
        if self.current_idx == 0:
            return current_k, current_v
        else:
            past_k = self.k[:, :, : self.current_idx, :]
            past_v = self.v[:, :, : self.current_idx, :]
            attn_k = torch.cat((past_k, current_k), dim=2)
            attn_v = torch.cat((past_v, current_v), dim=2)
            return attn_k, attn_v

    def update_cache(self, k, v):
        # Write a single new position; k and v have shape (2, num_heads, 1, head_dim).
        assert self.current_idx < self.max_len
        self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
        self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
        self.current_idx += 1

    def prefill_kv(self, k, v):
        # Bulk-write the prompt's K/V; k and v have shape (2, num_heads, prefill_len, head_dim).
        prefill_len = k.shape[2]
        assert prefill_len <= self.max_len
        self.k[:, :, :prefill_len, :] = k
        self.v[:, :, :prefill_len, :] = v
        self.current_idx = prefill_len

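# Illustrative sketch (not part of the model): prefill followed by one decode step.
# The leading batch dimension of 2 matches the cache's fixed allocation; the other
# sizes below are made up.
def _example_kv_cache() -> None:
    num_heads, max_len, head_dim = 4, 16, 8
    cache = KVCache(num_heads, max_len, head_dim, device=torch.device("cpu"))

    # Prefill with a 5-position prompt.
    k0 = torch.randn(2, num_heads, 5, head_dim)
    cache.prefill_kv(k0, k0.clone())
    assert cache.current_idx == 5

    # One autoregressive step: read the cached K/V combined with the new position...
    k1 = torch.randn(2, num_heads, 1, head_dim)
    attn_k, attn_v = cache.get_kv_for_attention(k1, k1.clone())
    assert attn_k.shape == (2, num_heads, 6, head_dim)

    # ...then write the new position into the cache for the next step.
    cache.update_cache(k1, k1.clone())
    assert cache.current_idx == 6
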
class Attention(nn.Module):
    """Attention using DenseGeneral."""

    def __init__(
        self,
        config: DiaConfig,
        q_embed_dim: int,
        kv_embed_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        head_dim: int,
        dropout_rate: float,
        is_cross_attn: bool = False,
        out_embed_dim: int | None = None,
    ):
        super().__init__()
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.is_cross_attn = is_cross_attn
        self.dropout_rate = dropout_rate
        compute_dtype = _str_to_dtype(config.training.dtype)
        weight_dtype = _str_to_dtype(config.model.weight_dtype)
        self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
        self.projected_query_dim = num_query_heads * head_dim
        if num_query_heads % num_kv_heads != 0:
            raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
        self.num_gqa_groups = num_query_heads // num_kv_heads

        # --- Projection Layers using DenseGeneral ---
        self.q_proj = DenseGeneral(
            in_shapes=(q_embed_dim,),
            out_features=(num_query_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.k_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.v_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.o_proj = DenseGeneral(
            in_shapes=(num_query_heads, head_dim),
            out_features=(self.output_dim,),
            axis=(-2, -1),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )

        # --- Rotary Embedding ---
        self.rotary_emb = RotaryEmbedding(
            embedding_dims=self.head_dim,
            min_timescale=config.model.rope_min_timescale,
            max_timescale=config.model.rope_max_timescale,
            dtype=compute_dtype,
        )
    def forward(
        self,
        Xq: torch.Tensor,  # (B, T, D) T = 1 in AR generation
        Xkv: torch.Tensor,  # (B, S, E) S = 1 in AR generation
        q_positions: torch.Tensor,  # (B, T)
        kv_positions: torch.Tensor | None = None,  # (B, S)
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,  # None in decoder self-attention, valid mask in others
        cache: KVCache | None = None,  # None in Encoder, KVCache in Decoder
        prefill: bool = False,  # True only when prefilling the KV cache
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        """
        Performs attention calculation with optional KV caching.

        Args:
            Xq: Query tensor (B, T, D). T=1 during single-step decoding.
            Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
            q_positions: Positions for queries (B, T).
            kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
            deterministic: If True, disable dropout.
            attn_mask: Attention mask.
            cache: KVCache.
            prefill: If True, use prefill mode.

        Returns:
            A tuple containing:
            - output: The attention output tensor (B, T, output_dim).
            - present_kv: The current step's K/V tensors (each (B, N, S, H)) for self-attention
              decode steps, so the caller can write them into the cache. None for cross-attention,
              encoder self-attention, and prefill.
        """
        if kv_positions is None:
            kv_positions = q_positions
        original_dtype = Xq.dtype

        Xq_BxTxNxH = self.q_proj(Xq)
        Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
        Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

        # Input values into attention calculation
        attn_k: torch.Tensor | None = None
        attn_v: torch.Tensor | None = None
        new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None

        # Decoder Cross Attention
        if self.is_cross_attn:
            # Directly use cache (no need to check index)
            attn_k, attn_v = cache.k, cache.v
            if attn_k.shape[1] != self.num_query_heads or attn_v.shape[1] != self.num_query_heads:
                raise ValueError(
                    f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
                    f"does not match num_query_heads ({self.num_query_heads}). "
                    "Cache should be pre-repeated for GQA."
                )
        # Self Attention
        else:
            Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
            Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
            Xk_BxSxKxH = self.rotary_emb(Xk_BxSxKxH, position=kv_positions)  # (B, S, K, H)

            Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
            Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
            # S=1 for Decode Step

            if self.num_gqa_groups > 1:
                Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
                Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
            else:
                Xk_BxNxSxH = Xk_BxKxSxH
                Xv_BxNxSxH = Xv_BxKxSxH

            # Encoder Self Attention
            if cache is None:
                attn_k = Xk_BxNxSxH
                attn_v = Xv_BxNxSxH
            # Decoder Self Attention
            else:
                # In prefill mode, we fill in cache until prefill length
                if prefill:
                    attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
                    cache.prefill_kv(attn_k, attn_v)
                # In decode step, we add current K/V to cache step by step
                else:
                    new_kv_cache = Xk_BxNxSxH, Xv_BxNxSxH
                    attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)

        attn_output = F.scaled_dot_product_attention(
            Xq_BxNxTxH,
            attn_k,
            attn_v,
            attn_mask=attn_mask,
            dropout_p=self.dropout_rate if not deterministic else 0.0,
            scale=1.0,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
        output = self.o_proj(attn_output)

        return output.to(original_dtype), new_kv_cache

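# Illustrative sketches (not part of the model): exercising Attention with a duck-typed
# stand-in config that only fills in the fields the module actually reads
# (training.dtype, model.weight_dtype, model.rope_min_timescale, model.rope_max_timescale).
# It is not a real DiaConfig, and all sizes are made up.
def _make_stub_attention_config():
    from types import SimpleNamespace

    return SimpleNamespace(
        training=SimpleNamespace(dtype="float32"),
        model=SimpleNamespace(weight_dtype="float32", rope_min_timescale=1, rope_max_timescale=10000),
    )


def _example_gqa_self_attention() -> None:
    # Encoder-style self-attention: no cache, 8 query heads sharing 2 KV heads (GQA).
    attn = Attention(
        config=_make_stub_attention_config(),
        q_embed_dim=64,
        kv_embed_dim=64,
        num_query_heads=8,
        num_kv_heads=2,
        head_dim=16,
        dropout_rate=0.0,
    )
    x = torch.randn(2, 10, 64)                          # (B, T, D)
    pos = torch.arange(10).unsqueeze(0).expand(2, 10)   # (B, T)
    out, new_kv = attn(Xq=x, Xkv=x, q_positions=pos, deterministic=True)
    assert out.shape == (2, 10, 64) and new_kv is None  # no cache => nothing to return


def _example_cross_attention_with_cache() -> None:
    # Decoder-style cross-attention: K/V are projected from the encoder output once,
    # stored in a KVCache, and reused for every decode step (this is what
    # Decoder.precompute_cross_attention_kv does per layer).
    attn = Attention(
        config=_make_stub_attention_config(),
        q_embed_dim=32,
        kv_embed_dim=48,
        num_query_heads=4,
        num_kv_heads=4,
        head_dim=8,
        dropout_rate=0.0,
        is_cross_attn=True,
        out_embed_dim=32,
    )
    encoder_out = torch.randn(2, 12, 48)                    # (B, S, E)
    src_pos = torch.arange(12).unsqueeze(0).expand(2, 12)   # (B, S)
    k = attn.rotary_emb(attn.k_proj(encoder_out), position=src_pos).transpose(1, 2)
    v = attn.v_proj(encoder_out).transpose(1, 2)
    cache = KVCache(attn.num_kv_heads, 12, attn.head_dim, k.device, k=k, v=v)

    x = torch.randn(2, 1, 32)                               # one decode step's hidden state
    pos = torch.zeros(2, 1, dtype=torch.long)
    out, _ = attn(Xq=x, Xkv=encoder_out, q_positions=pos, deterministic=True, cache=cache)
    assert out.shape == (2, 1, 32)
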
class EncoderLayer(nn.Module):
    """Transformer Encoder Layer using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder
        embed_dim = enc_config.n_embd

        self.pre_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.self_attention = Attention(
            config=config,
            q_embed_dim=embed_dim,
            kv_embed_dim=embed_dim,
            num_query_heads=enc_config.n_head,
            num_kv_heads=enc_config.n_head,
            head_dim=enc_config.head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=False,
            out_embed_dim=embed_dim,
        )
        self.post_sa_norm = RMSNorm(
            embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.mlp = MlpBlock(
            config=config,
            embed_dim=embed_dim,
            intermediate_dim=enc_config.n_hidden,
            activations=enc_config.mlp_activations,
            dropout_rate=model_config.dropout,
            use_pre_norm=enc_config.use_pre_norm,
        )
        self.dropout = nn.Dropout(model_config.dropout)

    def forward(
        self,
        x: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        residual = x
        x_norm = self.pre_sa_norm(x)

        sa_out, _ = self.self_attention(
            Xq=x_norm,
            Xkv=x_norm,
            q_positions=src_positions,
            kv_positions=src_positions,
            deterministic=deterministic,
            attn_mask=attn_mask,
        )
        x = residual + sa_out

        residual = x
        x_norm = self.post_sa_norm(x)
        mlp_out = self.mlp(x_norm, deterministic=deterministic)
        x = residual + mlp_out

        if not deterministic:
            x = self.dropout(x)
        return x

class Encoder(nn.Module):
    """Transformer Encoder Stack using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder
        compute_dtype = _str_to_dtype(config.training.dtype)

        self.embedding = nn.Embedding(
            model_config.src_vocab_size,
            enc_config.n_embd,
            dtype=compute_dtype,
        )
        self.dropout = nn.Dropout(model_config.dropout)
        self.layers = nn.ModuleList([EncoderLayer(config=config) for _ in range(enc_config.n_layer)])
        self.norm = RMSNorm(
            enc_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

    def forward(
        self,
        x_ids: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        x = self.embedding(x_ids)

        if not deterministic:
            x = self.dropout(x)

        for layer in self.layers:
            x = layer(
                x,
                src_positions=src_positions,
                deterministic=deterministic,
                attn_mask=attn_mask,
            )
        x = self.norm(x)

        if not deterministic:
            x = self.dropout(x)

        return x

class DecoderLayer(nn.Module):
    """Transformer Decoder Layer using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        enc_config = config.model.encoder
        dec_embed_dim = dec_config.n_embd
        enc_embed_dim = enc_config.n_embd

        # Norms
        self.pre_sa_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_ca_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_mlp_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        # Self-Attention (GQA) with Causal Masking
        self.self_attention = Attention(
            config=config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=dec_embed_dim,
            num_query_heads=dec_config.gqa_query_heads,
            num_kv_heads=dec_config.kv_heads,
            head_dim=dec_config.gqa_head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=False,
            out_embed_dim=dec_embed_dim,
        )
        # Cross-Attention (MHA)
        self.cross_attention = Attention(
            config=config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=enc_embed_dim,  # Note kv_embed_dim
            num_query_heads=dec_config.cross_query_heads,
            num_kv_heads=dec_config.cross_query_heads,
            head_dim=dec_config.cross_head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=True,
            out_embed_dim=dec_embed_dim,
        )
        # MLP
        self.mlp = MlpBlock(
            config=config,
            embed_dim=dec_embed_dim,
            intermediate_dim=dec_config.n_hidden,
            activations=dec_config.mlp_activations,
            dropout_rate=model_config.dropout,
            use_pre_norm=dec_config.use_pre_norm,
        )
    def forward(
        self,
        x: torch.Tensor,
        encoder_out: torch.Tensor,
        tgt_positions: torch.Tensor,
        src_positions: torch.Tensor | None,
        deterministic: bool,
        self_attn_mask: torch.Tensor,
        cross_attn_mask: torch.Tensor,
        self_attn_cache: KVCache,
        cross_attn_cache: KVCache,
        prefill: bool = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        # 1. Self-Attention
        residual = x
        x_norm = self.pre_sa_norm(x)

        sa_out, new_kv_cache = self.self_attention(
            Xq=x_norm,  # (2, 1, D)
            Xkv=x_norm,  # (2, 1, D)
            q_positions=tgt_positions,  # (2, 1)
            kv_positions=tgt_positions,  # (2, 1)
            deterministic=deterministic,
            attn_mask=self_attn_mask,  # (2, 1, 1, S_max)
            cache=self_attn_cache,
            prefill=prefill,
        )
        x = residual + sa_out

        # 2. Cross-Attention
        residual = x
        x_norm = self.pre_ca_norm(x)

        ca_out, _ = self.cross_attention(
            Xq=x_norm,
            Xkv=encoder_out,
            q_positions=tgt_positions,
            kv_positions=src_positions,
            deterministic=deterministic,
            attn_mask=cross_attn_mask,
            cache=cross_attn_cache,
        )
        x = residual + ca_out

        # 3. MLP
        residual = x
        x_norm = self.pre_mlp_norm(x)
        mlp_out = self.mlp(x_norm, deterministic=deterministic)
        x = residual + mlp_out

        return x, new_kv_cache

class Decoder(nn.Module):
    """Transformer Decoder Stack using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        train_config = config.training
        data_config = config.data
        compute_dtype = _str_to_dtype(config.training.dtype)
        weight_dtype = _str_to_dtype(config.model.weight_dtype)
        self.num_channels = data_config.channels
        self.num_layers = dec_config.n_layer

        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype)
                for _ in range(self.num_channels)
            ]
        )
        self.dropout = nn.Dropout(model_config.dropout)
        self.layers = nn.ModuleList([DecoderLayer(config=config) for _ in range(self.num_layers)])
        self.norm = RMSNorm(
            dec_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        # Final Logits Projection using DenseGeneral
        self.logits_dense = DenseGeneral(
            in_shapes=(dec_config.n_embd,),
            out_features=(self.num_channels, model_config.tgt_vocab_size),
            axis=(-1,),
            dtype=(torch.float32 if train_config.logits_dot_in_fp32 else compute_dtype),
            weight_dtype=weight_dtype,
        )
        self.logits_in_fp32 = train_config.logits_dot_in_fp32
    def precompute_cross_attention_kv(
        self,
        max_len: int,
        encoder_out: torch.Tensor,  # (B, S, E)
        src_positions: torch.Tensor | None,  # (B, S)
    ) -> list[KVCache]:
        """
        Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
        """
        per_layer_kv_cache: list[KVCache] = []

        for layer in self.layers:
            cross_attn_module = layer.cross_attention
            k_proj = cross_attn_module.k_proj(encoder_out)
            v_proj = cross_attn_module.v_proj(encoder_out)

            k_proj = cross_attn_module.rotary_emb(k_proj, position=src_positions)
            k = k_proj.transpose(1, 2)
            v = v_proj.transpose(1, 2)

            per_layer_kv_cache.append(
                KVCache(
                    cross_attn_module.num_kv_heads,
                    max_len,
                    cross_attn_module.head_dim,
                    k.device,
                    k=k,
                    v=v,
                )
            )

        return per_layer_kv_cache
    def decode_step(
        self,
        tgt_ids_Bx1xC: torch.Tensor,  # [B, 1, C]
        tgt_pos_Bx1: torch.Tensor,  # [B, 1]
        encoder_out: torch.Tensor,  # [B, S, E]
        self_attn_mask: Any,  # None
        cross_attn_mask: torch.Tensor,  # [B, 1, 1, S]
        self_attention_cache: list[KVCache],
        cross_attention_cache: list[KVCache],
    ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | None]]:
        """
        Performs a single decoding step, managing KV caches layer by layer.

        Returns:
            A tuple containing:
            - logits_Bx1xCxV: The output logits for the current step (B, 1, C, V), cast to float32.
            - new_cache: Per-layer (K, V) tensors for this step's self-attention, to be written
              into the corresponding self-attention caches by the caller (see the decoding-loop
              sketch after this class).
        """
        assert self_attn_mask is None, "Self-attention mask should be None, kept for pattern"

        x = None
        for i in range(self.num_channels):
            channel_tokens = tgt_ids_Bx1xC[..., i]
            channel_embed = self.embeddings[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed

        new_cache = []

        for i, layer in enumerate(self.layers):
            self_cache = self_attention_cache[i]
            cross_cache = cross_attention_cache[i]
            x, new_kv_cache = layer(
                x,  # (2, 1, D)
                encoder_out,  # (2, S, E)
                src_positions=None,  # CA KV is already computed
                tgt_positions=tgt_pos_Bx1,  # (2, 1)
                deterministic=True,
                self_attn_mask=None,
                cross_attn_mask=cross_attn_mask,
                self_attn_cache=self_cache,
                cross_attn_cache=cross_cache,
            )
            new_cache.append(new_kv_cache)

        x = self.norm(x)
        logits_Bx1xCxV = self.logits_dense(x)

        return logits_Bx1xCxV.to(torch.float32), new_cache
    def forward(
        self,
        tgt_ids_BxTxC: torch.Tensor,
        encoder_out: torch.Tensor,
        tgt_positions: torch.Tensor,
        src_positions: torch.Tensor,
        deterministic: bool,
        self_attn_mask: torch.Tensor,
        cross_attn_mask: torch.Tensor,
        self_attention_cache: list[KVCache],
        cross_attention_cache: list[KVCache],
    ) -> torch.Tensor:
        """
        Forward pass for the Decoder stack, managing (prefilling) the per-layer KV caches.

        Args:
            tgt_ids_BxTxC: Target token IDs (B, T, C).
            encoder_out: Output from the encoder (B, S, E).
            tgt_positions: Positions for target sequence (B, T).
            src_positions: Positions for source sequence (B, S).
            deterministic: Disable dropout if True.
            self_attn_mask: Mask for self-attention.
            cross_attn_mask: Mask for cross-attention.
            self_attention_cache: Per-layer self-attention KV caches; each layer prefills its
                cache with the K/V computed during this pass. `len(self_attention_cache)`
                should equal `num_layers`.
            cross_attention_cache: Per-layer cross-attention KV caches precomputed from
                `encoder_out` (see `precompute_cross_attention_kv`), passed to every layer.

        Returns:
            logits: The final output logits (B, T, C, V), cast to float32.
        """
        _, _, num_channels_in = tgt_ids_BxTxC.shape
        assert num_channels_in == self.num_channels, "Input channels mismatch"

        # Embeddings
        x = None
        for i in range(self.num_channels):
            channel_tokens = tgt_ids_BxTxC[..., i]
            channel_embed = self.embeddings[i](channel_tokens)
            x = channel_embed if x is None else x + channel_embed

        if not deterministic:
            x = self.dropout(x)

        for i, layer in enumerate(self.layers):
            x, _ = layer(
                x,
                encoder_out,
                tgt_positions=tgt_positions,
                src_positions=src_positions,
                deterministic=deterministic,
                self_attn_mask=self_attn_mask,
                cross_attn_mask=cross_attn_mask,
                self_attn_cache=self_attention_cache[i],
                cross_attn_cache=cross_attention_cache[i],
                prefill=True,
            )

        # Final Norm
        x = self.norm(x)
        logits_BxTxCxV = self.logits_dense(x)

        return logits_BxTxCxV.to(torch.float32)

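# Illustrative sketch (not part of the model): a greedy autoregressive loop around
# Decoder.decode_step. The caller supplies a real Decoder, prefilled self-attention
# caches, and any masks; the argmax below merely stands in for whatever sampling the
# surrounding generation code actually uses.
def _example_decode_loop(
    decoder: Decoder,
    encoder_out: torch.Tensor,             # (B, S, E)
    src_positions: torch.Tensor,           # (B, S)
    cross_attn_mask: torch.Tensor | None,  # (B, 1, 1, S) or None
    bos_frame_Bx1xC: torch.Tensor,         # (B, 1, C) initial codebook frame
    self_attention_cache: list[KVCache],   # per-layer caches, already prefilled
    num_steps: int,
) -> torch.Tensor:
    cross_attention_cache = decoder.precompute_cross_attention_kv(
        encoder_out.shape[1], encoder_out, src_positions
    )
    frame = bos_frame_Bx1xC
    frames = []
    for _ in range(num_steps):
        # The current frame's position equals the cache's write index.
        step = self_attention_cache[0].current_idx
        pos = torch.full((frame.shape[0], 1), step, device=frame.device, dtype=torch.long)
        logits_Bx1xCxV, new_cache = decoder.decode_step(
            frame,
            pos,
            encoder_out,
            None,
            cross_attn_mask,
            self_attention_cache,
            cross_attention_cache,
        )
        # Write this step's K/V into each layer's self-attention cache.
        for layer_cache, (k, v) in zip(self_attention_cache, new_cache):
            layer_cache.update_cache(k, v)
        frame = logits_Bx1xCxV.argmax(dim=-1)  # (B, 1, C): greedy pick per codebook channel
        frames.append(frame)
    return torch.cat(frames, dim=1)            # (B, num_steps, C)
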
class DiaModel(nn.Module):
    """PyTorch Dia Model using DenseGeneral."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)
    def forward(
        self,
        src_BxS: torch.Tensor,
        tgt_BxTxC: torch.Tensor,
        self_attention_cache: list[KVCache],
        cross_attention_cache: list[KVCache],
        src_positions: torch.Tensor | None = None,
        tgt_positions: torch.Tensor | None = None,
        enc_self_attn_mask: torch.Tensor | None = None,
        dec_self_attn_mask: torch.Tensor | None = None,
        dec_cross_attn_mask: torch.Tensor | None = None,
        enable_dropout: bool = True,
    ):
        deterministic = not enable_dropout

        # --- Encoder Pass ---
        encoder_out = self.encoder(
            x_ids=src_BxS,
            src_positions=src_positions,
            deterministic=deterministic,
            attn_mask=enc_self_attn_mask,
        )

        # --- Decoder Pass ---
        # Decoder.forward requires per-layer KV caches (it prefills the self-attention
        # caches and reads the precomputed cross-attention caches), so they are passed
        # through here rather than created internally.
        logits = self.decoder(
            tgt_ids_BxTxC=tgt_BxTxC,
            encoder_out=encoder_out,
            tgt_positions=tgt_positions,
            src_positions=src_positions,
            deterministic=deterministic,
            self_attn_mask=dec_self_attn_mask,
            cross_attn_mask=dec_cross_attn_mask,
            self_attention_cache=self_attention_cache,
            cross_attention_cache=cross_attention_cache,
        )

        return logits
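

# Illustrative sketch (not part of the model): wiring the pieces above for the prefill
# phase of inference. It allocates per-layer self-attention caches from the decoder
# config, runs the encoder, precomputes the cross-attention caches, and prefills the
# self-attention caches via Decoder.forward. Attention masks are left as None for
# brevity (the real pipeline builds causal/padding masks elsewhere), and the batch is
# assumed to be 2 because KVCache allocates a fixed leading dimension of 2.
def _example_prefill(
    model: DiaModel,
    src_BxS: torch.Tensor,    # (2, S) text token ids
    tgt_BxTxC: torch.Tensor,  # (2, T, C) audio codebook ids for the prompt
    max_audio_len: int,
) -> tuple[torch.Tensor, list[KVCache], list[KVCache]]:
    dec_cfg = model.config.model.decoder
    B, S = src_BxS.shape
    T = tgt_BxTxC.shape[1]
    device = src_BxS.device

    src_pos = torch.arange(S, device=device).unsqueeze(0).expand(B, S)
    tgt_pos = torch.arange(T, device=device).unsqueeze(0).expand(B, T)

    encoder_out = model.encoder(src_BxS, src_positions=src_pos, deterministic=True)

    self_cache = [
        KVCache(dec_cfg.gqa_query_heads, max_audio_len, dec_cfg.gqa_head_dim, device)
        for _ in range(dec_cfg.n_layer)
    ]
    cross_cache = model.decoder.precompute_cross_attention_kv(S, encoder_out, src_pos)

    logits = model.decoder(
        tgt_ids_BxTxC=tgt_BxTxC,
        encoder_out=encoder_out,
        tgt_positions=tgt_pos,
        src_positions=src_pos,
        deterministic=True,
        self_attn_mask=None,
        cross_attn_mask=None,
        self_attention_cache=self_cache,
        cross_attention_cache=cross_cache,
    )
    # After this call each self_cache[i].current_idx == T; generation can continue
    # with _example_decode_loop above.
    return logits, self_cache, cross_cache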