|
from typing import Any |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from torch import Tensor |
|
from torch.nn import RMSNorm |
|
|
|
from .config import DiaConfig |
|
|
|
|
|
def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
    """Map each axis index in ``axes`` to its non-negative form for a rank-``ndim`` tensor.

    Negative indices count from the end (``-1`` becomes ``ndim - 1``). Non-negative
    indices are returned unchanged, with no bounds checking.
    """

    def _resolve(axis: int) -> int:
        return axis + ndim if axis < 0 else axis

    return tuple(map(_resolve, axes))
|
|
|
|
|
def _str_to_dtype(dtype_str: str) -> torch.dtype | None: |
|
|
|
if dtype_str is None or dtype_str.lower() == "none": |
|
return None |
|
if dtype_str == "float32": |
|
return torch.float32 |
|
elif dtype_str == "float16": |
|
return torch.float16 |
|
elif dtype_str == "bfloat16": |
|
return torch.bfloat16 |
|
else: |
|
raise ValueError(f"Unsupported dtype string: {dtype_str}") |
|
|
|
|
|
class DenseGeneral(nn.Module): |
|
""" |
|
PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init. |
|
|
|
Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot |
|
for the generalized matrix multiplication. Weight/bias shapes are calculated |
|
and parameters created during initialization based on config. |
|
`load_weights` validates shapes and copies data. |
|
|
|
Attributes: |
|
axis (Tuple[int, ...]): Input axis or axes to contract. |
|
in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`. |
|
out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims). |
|
use_bias (bool): Whether to add a bias term. |
|
weight (nn.Parameter): The kernel parameter. |
|
bias (Optional[nn.Parameter]): The bias parameter (if use_bias=True). |
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_shapes: tuple[int, ...], |
|
out_features: tuple[int, ...], |
|
axis: tuple[int, ...] = (-1,), |
|
dtype: torch.dtype | None = None, |
|
weight_dtype: torch.dtype | None = None, |
|
device: torch.device | None = None, |
|
): |
|
super().__init__() |
|
self.in_shapes = in_shapes |
|
self.out_features = out_features |
|
self.axis = axis |
|
self.dtype = dtype |
|
self.kernel_shape = self.in_shapes + self.out_features |
|
|
|
factory_kwargs = {"device": device, "dtype": weight_dtype} |
|
self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs)) |
|
self.register_parameter("bias", None) |
|
|
|
def forward(self, inputs: Tensor) -> Tensor: |
|
norm_axis = _normalize_axes(self.axis, inputs.ndim) |
|
kernel_contract_axes = tuple(range(len(norm_axis))) |
|
|
|
output = torch.tensordot( |
|
inputs.float(), |
|
self.weight.float(), |
|
dims=(norm_axis, kernel_contract_axes), |
|
).to(inputs.dtype) |
|
return output |
|
|
|
|
|
def get_activation_fn(activation_string: str) -> nn.Module: |
|
"""Maps activation string to PyTorch activation function module.""" |
|
if activation_string == "gelu": |
|
return nn.GELU() |
|
elif activation_string == "relu": |
|
return nn.ReLU() |
|
elif activation_string == "silu" or activation_string == "swish": |
|
return nn.SiLU() |
|
elif activation_string == "linear": |
|
return nn.Identity() |
|
else: |
|
raise ValueError(f"Unsupported activation function: {activation_string}") |
|
|
|
|
|
class MlpBlock(nn.Module):
    """Gated MLP block built from DenseGeneral.

    A single fused input projection produces two branches of width
    ``intermediate_dim``. Branch 0 goes through ``activations[0]`` and branch 1
    through ``activations[1]``. The results are multiplied element-wise,
    optionally dropped out, and projected back to ``embed_dim``.
    """

    def __init__(
        self,
        config: DiaConfig,
        embed_dim: int,
        intermediate_dim: int,
        dropout_rate: float,
        activations: tuple[str, ...] | list[str] = ("silu", "linear"),
        use_pre_norm: bool = False,
    ):
        super().__init__()
        # forward() consumes exactly two fused branches (gate and up). Fail early
        # instead of an IndexError (too few) or silently unused weights (too many).
        if len(activations) != 2:
            raise ValueError(
                f"MlpBlock expects exactly 2 activations, got {len(activations)}: {list(activations)}"
            )
        self.use_pre_norm = use_pre_norm
        num_activations = len(activations)
        compute_dtype = _str_to_dtype(config.training.dtype)
        weight_dtype = _str_to_dtype(config.model.weight_dtype)
        self.dtype = compute_dtype

        if use_pre_norm:
            self.pre_norm = RMSNorm(
                embed_dim,
                eps=config.model.normalization_layer_epsilon,
                dtype=torch.float32,
            )

        # Kernel shape (embed_dim, 2, intermediate_dim): both branches in one matmul.
        self.wi_fused = DenseGeneral(
            in_shapes=(embed_dim,),
            out_features=(
                num_activations,
                intermediate_dim,
            ),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )

        self.activation_fn_0 = get_activation_fn(activations[0])
        self.activation_fn_1 = get_activation_fn(activations[1])

        self.dropout = nn.Dropout(dropout_rate)

        self.wo = DenseGeneral(
            in_shapes=(intermediate_dim,),
            out_features=(embed_dim,),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )

    def forward(self, x: torch.Tensor, deterministic: bool) -> torch.Tensor:
        """Apply the gated MLP to ``x`` (last dim ``embed_dim``).

        Args:
            x: Input activations whose last dimension is ``embed_dim``.
            deterministic: If True, dropout is skipped.

        Returns:
            Tensor with the same leading dims as ``x`` and last dim ``embed_dim``.
        """
        if self.use_pre_norm and hasattr(self, "pre_norm"):
            x = self.pre_norm(x)

        # wi_fused output: (..., 2, intermediate_dim); index the branch axis.
        fused_x = self.wi_fused(x)
        gate = self.activation_fn_0(fused_x[..., 0, :])
        up = self.activation_fn_1(fused_x[..., 1, :])

        hidden = torch.mul(gate, up)
        # Only cast when a compute dtype was configured ("none" maps to None).
        if self.dtype is not None:
            hidden = hidden.to(self.dtype)

        if not deterministic:
            hidden = self.dropout(hidden)

        return self.wo(hidden)
|
|
|
|
|
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) implementation in PyTorch.

    The last dimension of the input is split into two halves that are rotated
    against each other by angles ``position / timescale``. The timescales form
    a geometric progression from ``min_timescale`` toward ``max_timescale``.
    """

    def __init__(
        self,
        embedding_dims: int,
        min_timescale: int = 1,
        max_timescale: int = 10000,
        dtype: torch.dtype = torch.float32,
    ):
        super().__init__()
        if embedding_dims % 2:
            raise ValueError("Embedding dim must be even for RoPE.")
        self.embedding_dims = embedding_dims
        self.min_timescale = min_timescale
        self.max_timescale = max_timescale
        self.dtype = dtype

        # Exponents 0, 2/d, 4/d, ... (one per rotated pair).
        exponents = (2.0 * torch.arange(0, embedding_dims // 2)) / embedding_dims
        timescale = self.min_timescale * (self.max_timescale / self.min_timescale) ** exponents
        # Non-persistent: recomputed at construction, not saved in state_dict.
        self.register_buffer("timescale", timescale, persistent=False)

    def extra_repr(self) -> str:
        return str(self.timescale.shape)

    def forward(self, inputs: torch.Tensor, position: torch.Tensor):
        """Rotate ``inputs`` according to ``position``.

        ``position`` gains two trailing singleton dims so that it broadcasts
        against the timescales and the inputs' second-to-last dim.
        """
        angles = position[..., None, None] / self.timescale.to(inputs.device)
        sin = angles.sin().to(inputs.dtype)
        cos = angles.cos().to(inputs.dtype)

        x1, x2 = inputs.chunk(2, dim=-1)
        rotated = [x1 * cos - x2 * sin, x2 * cos + x1 * sin]
        return torch.cat(rotated, dim=-1)
|
|
|
|
|
class KVCache: |
|
def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None): |
|
self.k = torch.zeros((2, num_heads, max_len, head_dim), device=device) if k is None else k |
|
self.v = torch.zeros((2, num_heads, max_len, head_dim), device=device) if v is None else v |
|
self.current_idx = 0 |
|
self.max_len = max_len |
|
|
|
def get_kv_for_attention(self, current_k, current_v): |
|
if self.current_idx == 0: |
|
return current_k, current_v |
|
else: |
|
past_k = self.k[:, :, : self.current_idx, :] |
|
past_v = self.v[:, :, : self.current_idx, :] |
|
attn_k = torch.cat((past_k, current_k), dim=2) |
|
attn_v = torch.cat((past_v, current_v), dim=2) |
|
return attn_k, attn_v |
|
|
|
def update_cache(self, k, v): |
|
assert self.current_idx < self.max_len |
|
self.k[:, :, self.current_idx : self.current_idx + 1, :] = k |
|
self.v[:, :, self.current_idx : self.current_idx + 1, :] = v |
|
self.current_idx += 1 |
|
|
|
def prefill_kv(self, k, v): |
|
prefill_len = k.shape[2] |
|
assert prefill_len <= self.max_len |
|
self.k[:, :, :prefill_len, :] = k |
|
self.v[:, :, :prefill_len, :] = v |
|
self.current_idx = prefill_len |
|
|
|
|
|
class Attention(nn.Module):
    """Multi-head attention (with optional grouped-query K/V heads) using DenseGeneral.

    Rotary position embeddings are applied to the queries and, for
    self-attention, to the keys. In cross-attention mode the module does not
    project ``Xkv``. Keys and values are read from a precomputed ``KVCache``,
    whose head dimension must already equal ``num_query_heads``.
    """

    def __init__(
        self,
        config: DiaConfig,
        q_embed_dim: int,
        kv_embed_dim: int,
        num_query_heads: int,
        num_kv_heads: int,
        head_dim: int,
        dropout_rate: float,
        is_cross_attn: bool = False,
        out_embed_dim: int | None = None,
    ):
        super().__init__()
        self.num_query_heads = num_query_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.is_cross_attn = is_cross_attn
        self.dropout_rate = dropout_rate
        compute_dtype = _str_to_dtype(config.training.dtype)
        weight_dtype = _str_to_dtype(config.model.weight_dtype)
        self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
        self.projected_query_dim = num_query_heads * head_dim
        if num_query_heads % num_kv_heads != 0:
            raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
        # Number of query heads sharing each K/V head.
        self.num_gqa_groups = num_query_heads // num_kv_heads

        self.q_proj = DenseGeneral(
            in_shapes=(q_embed_dim,),
            out_features=(num_query_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.k_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.v_proj = DenseGeneral(
            in_shapes=(kv_embed_dim,),
            out_features=(num_kv_heads, head_dim),
            axis=(-1,),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )
        self.o_proj = DenseGeneral(
            in_shapes=(num_query_heads, head_dim),
            out_features=(self.output_dim,),
            axis=(-2, -1),
            dtype=compute_dtype,
            weight_dtype=weight_dtype,
        )

        self.rotary_emb = RotaryEmbedding(
            embedding_dims=self.head_dim,
            min_timescale=config.model.rope_min_timescale,
            max_timescale=config.model.rope_max_timescale,
            dtype=compute_dtype,
        )

    def forward(
        self,
        Xq: torch.Tensor,
        Xkv: torch.Tensor,
        q_positions: torch.Tensor,
        kv_positions: torch.Tensor | None = None,
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,
        cache: KVCache | None = None,
        prefill: bool = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        """
        Performs attention calculation with optional KV caching.

        Args:
            Xq: Query tensor (B, T, D). T=1 during single-step decoding.
            Xkv: Key/Value source tensor (B, S, E). Ignored for cross-attention,
                which reads K/V from ``cache``.
            q_positions: Positions for queries (B, T).
            kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
            deterministic: If True, disable attention dropout.
            attn_mask: Attention mask forwarded to scaled_dot_product_attention.
            cache: KVCache. Required for cross-attention. For self-attention it is
                optional: None disables caching, otherwise see ``prefill``.
            prefill: Self-attention with a cache only. If True, this call's K/V are
                written into the cache via ``prefill_kv``. If False, attention runs
                over the cached prefix plus this call's K/V.

        Returns:
            A tuple containing:
            - output: The attention output tensor (B, T, output_dim), in Xq's dtype.
            - present_kv: In the cached, non-prefill self-attention path, the (K, V)
              computed for *this call only* (after GQA head repetition, layout
              (B, N, S, H)); the cache itself is not updated here. Otherwise None.

        Raises:
            ValueError: In cross-attention mode, if ``cache`` is None or its head
                dimension differs from ``num_query_heads``.
        """
        if kv_positions is None:
            kv_positions = q_positions
        original_dtype = Xq.dtype

        Xq_BxTxNxH = self.q_proj(Xq)
        Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
        Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)

        new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None

        if self.is_cross_attn:
            if cache is None:
                raise ValueError("Cross-attention requires a precomputed KVCache, but cache=None was given.")
            attn_k, attn_v = cache.k, cache.v
            if attn_k.shape[1] != self.num_query_heads or attn_v.shape[1] != self.num_query_heads:
                raise ValueError(
                    f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
                    f"does not match num_query_heads ({self.num_query_heads}). "
                    "Cache should be pre-repeated for GQA."
                )
        else:
            Xk_BxSxKxH = self.k_proj(Xkv)
            Xv_BxSxKxH = self.v_proj(Xkv)
            Xk_BxSxKxH = self.rotary_emb(Xk_BxSxKxH, position=kv_positions)

            Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)
            Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)

            # GQA: repeat each K/V head so every query head has a matching K/V head.
            if self.num_gqa_groups > 1:
                Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
                Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
            else:
                Xk_BxNxSxH = Xk_BxKxSxH
                Xv_BxNxSxH = Xv_BxKxSxH

            if cache is None:
                attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
            elif prefill:
                attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
                cache.prefill_kv(attn_k, attn_v)
            else:
                # Handed back to the caller; this method does not write it to the cache.
                new_kv_cache = Xk_BxNxSxH, Xv_BxNxSxH
                attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)

        attn_output = F.scaled_dot_product_attention(
            Xq_BxNxTxH,
            attn_k,
            attn_v,
            attn_mask=attn_mask,
            dropout_p=self.dropout_rate if not deterministic else 0.0,
            # No 1/sqrt(head_dim) scaling here. NOTE(review): presumably the
            # scaling is folded into the query projection weights; confirm.
            scale=1.0,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        output = self.o_proj(attn_output)

        return output.to(original_dtype), new_kv_cache
|
|
|
|
|
class EncoderLayer(nn.Module):
    """One pre-norm Transformer encoder layer built on DenseGeneral.

    Computes ``h = x + SelfAttn(pre_sa_norm(x))`` followed by
    ``out = h + MLP(post_sa_norm(h))``. When not deterministic, dropout is
    applied to ``out``.
    """

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_cfg = config.model
        encoder_cfg = model_cfg.encoder
        width = encoder_cfg.n_embd
        eps = model_cfg.normalization_layer_epsilon
        drop = model_cfg.dropout

        def _rms_norm() -> RMSNorm:
            return RMSNorm(width, eps=eps, dtype=torch.float32)

        self.pre_sa_norm = _rms_norm()
        self.self_attention = Attention(
            config=config,
            q_embed_dim=width,
            kv_embed_dim=width,
            num_query_heads=encoder_cfg.n_head,
            num_kv_heads=encoder_cfg.n_head,
            head_dim=encoder_cfg.head_dim,
            dropout_rate=drop,
            is_cross_attn=False,
            out_embed_dim=width,
        )
        self.post_sa_norm = _rms_norm()
        self.mlp = MlpBlock(
            config=config,
            embed_dim=width,
            intermediate_dim=encoder_cfg.n_hidden,
            activations=encoder_cfg.mlp_activations,
            dropout_rate=drop,
            use_pre_norm=encoder_cfg.use_pre_norm,
        )
        self.dropout = nn.Dropout(drop)

    def forward(
        self,
        x: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run self-attention and the MLP, each with a residual connection.

        ``src_positions`` is used as both the query and the key positions.
        """
        normed = self.pre_sa_norm(x)
        attn_out, _ = self.self_attention(
            Xq=normed,
            Xkv=normed,
            q_positions=src_positions,
            kv_positions=src_positions,
            deterministic=deterministic,
            attn_mask=attn_mask,
        )
        hidden = x + attn_out
        hidden = hidden + self.mlp(self.post_sa_norm(hidden), deterministic=deterministic)
        return hidden if deterministic else self.dropout(hidden)
|
|
|
|
|
class Encoder(nn.Module):
    """Transformer Encoder Stack using DenseGeneral.

    The stack is: token embedding, then ``n_layer`` EncoderLayers, then a final RMSNorm.
    """

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        enc_config = config.model.encoder
        compute_dtype = _str_to_dtype(config.training.dtype)

        self.embedding = nn.Embedding(
            model_config.src_vocab_size,
            enc_config.n_embd,
            dtype=compute_dtype,
        )
        self.dropout = nn.Dropout(model_config.dropout)
        self.layers = nn.ModuleList([EncoderLayer(config=config) for _ in range(enc_config.n_layer)])
        self.norm = RMSNorm(
            enc_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

    def forward(
        self,
        x_ids: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        deterministic: bool = True,
        attn_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode source token ids.

        Args:
            x_ids: Source token ids. Assumed (B, S); TODO confirm. The attention
                layers transpose dims 1 and 2 of the embedded tensor.
            src_positions: Positions used by the rotary embeddings. When None,
                positions ``0..S-1`` are used for every batch row. The rotary
                embedding cannot accept None.
            deterministic: If True, dropout is disabled.
            attn_mask: Self-attention mask passed to every layer.

        Returns:
            The normalized encoder output.
        """
        if src_positions is None:
            batch_size, seq_len = x_ids.shape[0], x_ids.shape[1]
            src_positions = torch.arange(seq_len, device=x_ids.device).unsqueeze(0).expand(batch_size, -1)

        x = self.embedding(x_ids)

        if not deterministic:
            x = self.dropout(x)

        for layer in self.layers:
            x = layer(
                x,
                src_positions=src_positions,
                deterministic=deterministic,
                attn_mask=attn_mask,
            )
        x = self.norm(x)
        if not deterministic:
            x = self.dropout(x)
        return x
|
|
|
|
|
class DecoderLayer(nn.Module):
    """Transformer Decoder Layer using DenseGeneral.

    Three pre-norm residual sub-blocks run in order: self-attention (GQA,
    optionally cached), cross-attention over the encoder output (K/V from a
    precomputed cache), and a gated MLP.
    """

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        enc_config = config.model.encoder
        dec_embed_dim = dec_config.n_embd
        enc_embed_dim = enc_config.n_embd

        self.pre_sa_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_ca_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )
        self.pre_mlp_norm = RMSNorm(
            dec_embed_dim,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        self.self_attention = Attention(
            config=config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=dec_embed_dim,
            num_query_heads=dec_config.gqa_query_heads,
            num_kv_heads=dec_config.kv_heads,
            head_dim=dec_config.gqa_head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=False,
            out_embed_dim=dec_embed_dim,
        )

        # num_kv_heads == num_query_heads, so the cross-attention cache needs no GQA repetition.
        self.cross_attention = Attention(
            config=config,
            q_embed_dim=dec_embed_dim,
            kv_embed_dim=enc_embed_dim,
            num_query_heads=dec_config.cross_query_heads,
            num_kv_heads=dec_config.cross_query_heads,
            head_dim=dec_config.cross_head_dim,
            dropout_rate=model_config.dropout,
            is_cross_attn=True,
            out_embed_dim=dec_embed_dim,
        )

        self.mlp = MlpBlock(
            config=config,
            embed_dim=dec_embed_dim,
            intermediate_dim=dec_config.n_hidden,
            activations=dec_config.mlp_activations,
            dropout_rate=model_config.dropout,
            use_pre_norm=dec_config.use_pre_norm,
        )

    def forward(
        self,
        x: torch.Tensor,
        encoder_out: torch.Tensor,
        tgt_positions: torch.Tensor,
        src_positions: torch.Tensor | None,
        deterministic: bool,
        self_attn_mask: torch.Tensor | None,
        cross_attn_mask: torch.Tensor | None,
        self_attn_cache: KVCache | None,
        cross_attn_cache: KVCache,
        prefill: bool = False,
    ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
        """Run the three residual sub-blocks on ``x``.

        Args:
            x: Decoder hidden states.
            encoder_out: Encoder output. The cross-attention reads its K/V from
                ``cross_attn_cache`` rather than from this tensor.
            tgt_positions: Positions for the rotary embeddings of decoder queries/keys.
            src_positions: Forwarded as cross-attention key positions. The
                cross-attention path does not use them.
            deterministic: If True, dropout is disabled.
            self_attn_mask: Mask for self-attention.
            cross_attn_mask: Mask for cross-attention.
            self_attn_cache: Self-attention KVCache, or None for no caching.
            cross_attn_cache: Precomputed cross-attention KVCache (required).
            prefill: Forwarded to the self-attention (see Attention.forward).

        Returns:
            ``(x, new_kv_cache)``. ``new_kv_cache`` is the self-attention's
            ``present_kv``: the current call's (K, V) in the cached non-prefill
            path, otherwise None.
        """
        residual = x
        x_norm = self.pre_sa_norm(x)
        sa_out, new_kv_cache = self.self_attention(
            Xq=x_norm,
            Xkv=x_norm,
            q_positions=tgt_positions,
            kv_positions=tgt_positions,
            deterministic=deterministic,
            attn_mask=self_attn_mask,
            cache=self_attn_cache,
            prefill=prefill,
        )
        x = residual + sa_out

        residual = x
        x_norm = self.pre_ca_norm(x)
        ca_out, _ = self.cross_attention(
            Xq=x_norm,
            Xkv=encoder_out,
            q_positions=tgt_positions,
            kv_positions=src_positions,
            deterministic=deterministic,
            attn_mask=cross_attn_mask,
            cache=cross_attn_cache,
        )
        x = residual + ca_out

        residual = x
        x_norm = self.pre_mlp_norm(x)
        mlp_out = self.mlp(x_norm, deterministic=deterministic)
        x = residual + mlp_out

        return x, new_kv_cache
|
|
|
|
|
class Decoder(nn.Module):
    """Transformer Decoder Stack using DenseGeneral.

    Target ids carry ``num_channels`` parallel channels. Each channel has its
    own embedding table, and the per-channel embeddings are summed. Logits are
    produced for every channel over ``tgt_vocab_size``.
    """

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        model_config = config.model
        dec_config = config.model.decoder
        train_config = config.training
        data_config = config.data
        compute_dtype = _str_to_dtype(config.training.dtype)
        weight_dtype = _str_to_dtype(config.model.weight_dtype)
        self.num_channels = data_config.channels
        self.num_layers = dec_config.n_layer

        self.embeddings = nn.ModuleList(
            [
                nn.Embedding(model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype)
                for _ in range(self.num_channels)
            ]
        )
        self.dropout = nn.Dropout(model_config.dropout)
        self.layers = nn.ModuleList([DecoderLayer(config=config) for _ in range(self.num_layers)])
        self.norm = RMSNorm(
            dec_config.n_embd,
            eps=model_config.normalization_layer_epsilon,
            dtype=torch.float32,
        )

        # Output kernel (n_embd, C, V): logits for every channel in one projection.
        self.logits_dense = DenseGeneral(
            in_shapes=(dec_config.n_embd,),
            out_features=(self.num_channels, model_config.tgt_vocab_size),
            axis=(-1,),
            dtype=(torch.float32 if train_config.logits_dot_in_fp32 else compute_dtype),
            weight_dtype=weight_dtype,
        )
        self.logits_in_fp32 = train_config.logits_dot_in_fp32

    def _embed_tokens(self, tgt_ids: torch.Tensor) -> torch.Tensor:
        """Sum per-channel embeddings; channel ``i`` (last dim of ``tgt_ids``) uses ``embeddings[i]``."""
        x = None
        for i in range(self.num_channels):
            channel_embed = self.embeddings[i](tgt_ids[..., i])
            x = channel_embed if x is None else x + channel_embed
        return x

    def precompute_cross_attention_kv(
        self,
        max_len: int,
        encoder_out: torch.Tensor,
        src_positions: torch.Tensor | None,
    ) -> list[KVCache]:
        """
        Computes the Key and Value tensors for cross-attention for each layer from the encoder output.

        Keys receive rotary embeddings at ``src_positions``. Each layer's K/V are
        transposed to heads-first layout and wrapped in a KVCache. Because K/V are
        supplied, no zero buffers are allocated, and ``max_len`` only sets the
        cache's ``max_len`` attribute.
        """
        per_layer_kv_cache: list[KVCache] = []

        for layer in self.layers:
            cross_attn_module = layer.cross_attention
            k_proj = cross_attn_module.k_proj(encoder_out)
            v_proj = cross_attn_module.v_proj(encoder_out)

            k_proj = cross_attn_module.rotary_emb(k_proj, position=src_positions)
            k = k_proj.transpose(1, 2)
            v = v_proj.transpose(1, 2)

            per_layer_kv_cache.append(
                KVCache(
                    cross_attn_module.num_kv_heads,
                    max_len,
                    cross_attn_module.head_dim,
                    k.device,
                    k=k,
                    v=v,
                )
            )

        return per_layer_kv_cache

    def decode_step(
        self,
        tgt_ids_Bx1xC: torch.Tensor,
        tgt_pos_Bx1: torch.Tensor,
        encoder_out: torch.Tensor,
        self_attn_mask: Any,
        cross_attn_mask: torch.Tensor,
        self_attention_cache: list[KVCache],
        cross_attention_cache: list[KVCache],
    ) -> tuple[torch.Tensor, list[tuple[torch.Tensor, torch.Tensor] | None]]:
        """
        Performs a single decoding step, managing KV caches layer by layer.

        Returns:
            A tuple containing:
            - logits_Bx1xCxV: Output logits for the current step (B, 1, C, V), cast to float32.
            - new_cache: Per layer, the self-attention (K, V) computed for this step
              (see Attention.forward ``present_kv``). The caches are not updated here.
        """
        assert self_attn_mask is None, "Self-attention mask should be None, kept for pattern"

        x = self._embed_tokens(tgt_ids_Bx1xC)

        new_cache = []
        for i, layer in enumerate(self.layers):
            x, new_kv_cache = layer(
                x,
                encoder_out,
                src_positions=None,
                tgt_positions=tgt_pos_Bx1,
                deterministic=True,
                self_attn_mask=None,
                cross_attn_mask=cross_attn_mask,
                self_attn_cache=self_attention_cache[i],
                cross_attn_cache=cross_attention_cache[i],
            )
            new_cache.append(new_kv_cache)

        x = self.norm(x)
        logits_Bx1xCxV = self.logits_dense(x)

        return logits_Bx1xCxV.to(torch.float32), new_cache

    def forward(
        self,
        tgt_ids_BxTxC: torch.Tensor,
        encoder_out: torch.Tensor,
        tgt_positions: torch.Tensor,
        src_positions: torch.Tensor,
        deterministic: bool,
        self_attn_mask: torch.Tensor,
        cross_attn_mask: torch.Tensor,
        self_attention_cache: list[KVCache | None] | None = None,
        cross_attention_cache: list[KVCache] | None = None,
    ) -> torch.Tensor:
        """
        Forward pass for the Decoder stack over a full target sequence.

        Args:
            tgt_ids_BxTxC: Target token IDs (B, T, C).
            encoder_out: Output from the encoder (B, S, E).
            tgt_positions: Positions for target sequence (B, T).
            src_positions: Positions for source sequence (B, S).
            deterministic: Disable dropout if True.
            self_attn_mask: Mask for self-attention.
            cross_attn_mask: Mask for cross-attention.
            self_attention_cache: One entry per layer. A KVCache entry is
                prefilled with that layer's self-attention K/V; a None entry
                disables caching for that layer. If the whole argument is None,
                no layer caches.
            cross_attention_cache: One precomputed cross-attention KVCache per
                layer. If None, it is built with
                ``precompute_cross_attention_kv(T, encoder_out, src_positions)``.

        Returns:
            logits: The final output logits (B, T, C, V), cast to float32.
        """
        _, _, num_channels_in = tgt_ids_BxTxC.shape
        assert num_channels_in == self.num_channels, "Input channels mismatch"

        if self_attention_cache is None:
            self_attention_cache = [None] * self.num_layers
        if cross_attention_cache is None:
            cross_attention_cache = self.precompute_cross_attention_kv(
                tgt_ids_BxTxC.shape[1], encoder_out, src_positions
            )

        x = self._embed_tokens(tgt_ids_BxTxC)

        if not deterministic:
            x = self.dropout(x)

        for i, layer in enumerate(self.layers):
            x, _ = layer(
                x,
                encoder_out,
                tgt_positions=tgt_positions,
                src_positions=src_positions,
                deterministic=deterministic,
                self_attn_mask=self_attn_mask,
                cross_attn_mask=cross_attn_mask,
                self_attn_cache=self_attention_cache[i],
                cross_attn_cache=cross_attention_cache[i],
                prefill=True,
            )

        x = self.norm(x)
        logits_BxTxCxV = self.logits_dense(x)

        return logits_BxTxCxV.to(torch.float32)
|
|
|
|
|
class DiaModel(nn.Module):
    """PyTorch Dia Model using DenseGeneral: an Encoder over source ids and a multi-channel Decoder."""

    def __init__(self, config: DiaConfig):
        super().__init__()
        self.config = config
        self.encoder = Encoder(config)
        self.decoder = Decoder(config)

    @staticmethod
    def _default_positions(batch_size: int, length: int, device: torch.device) -> torch.Tensor:
        """Positions ``0..length-1`` broadcast to every batch row, shape (batch_size, length)."""
        return torch.arange(length, device=device).unsqueeze(0).expand(batch_size, -1)

    def forward(
        self,
        src_BxS: torch.Tensor,
        tgt_BxTxC: torch.Tensor,
        src_positions: torch.Tensor | None = None,
        tgt_positions: torch.Tensor | None = None,
        enc_self_attn_mask: torch.Tensor | None = None,
        dec_self_attn_mask: torch.Tensor | None = None,
        dec_cross_attn_mask: torch.Tensor | None = None,
        enable_dropout: bool = True,
    ) -> torch.Tensor:
        """Full (non-incremental) encoder-decoder pass.

        Args:
            src_BxS: Source token ids (B, S).
            tgt_BxTxC: Target token ids (B, T, C).
            src_positions / tgt_positions: Rotary positions. When None, they
                default to ``0..len-1`` per batch row.
            enc_self_attn_mask, dec_self_attn_mask, dec_cross_attn_mask: Masks
                forwarded to the attention layers. No causal mask is added here;
                supply it via ``dec_self_attn_mask`` if needed.
            enable_dropout: Enables dropout when True.

        Returns:
            Decoder logits (B, T, C, V) in float32.
        """
        deterministic = not enable_dropout

        # The rotary embeddings cannot handle None positions.
        if src_positions is None:
            src_positions = self._default_positions(src_BxS.shape[0], src_BxS.shape[1], src_BxS.device)
        if tgt_positions is None:
            tgt_positions = self._default_positions(tgt_BxTxC.shape[0], tgt_BxTxC.shape[1], tgt_BxTxC.device)

        encoder_out = self.encoder(
            x_ids=src_BxS,
            src_positions=src_positions,
            deterministic=deterministic,
            attn_mask=enc_self_attn_mask,
        )

        # Cross-attention reads K/V only from a cache, so build one per layer.
        cross_attention_cache = self.decoder.precompute_cross_attention_kv(
            tgt_BxTxC.shape[1], encoder_out, src_positions
        )
        # Full-sequence pass: a None cache makes self-attention use this call's K/V directly.
        self_attention_cache = [None] * len(self.decoder.layers)

        logits = self.decoder(
            tgt_ids_BxTxC=tgt_BxTxC,
            encoder_out=encoder_out,
            tgt_positions=tgt_positions,
            src_positions=src_positions,
            deterministic=deterministic,
            self_attn_mask=dec_self_attn_mask,
            cross_attn_mask=dec_cross_attn_mask,
            self_attention_cache=self_attention_cache,
            cross_attention_cache=cross_attention_cache,
        )

        return logits
|
|