Michael Hu committed
Commit 9c4b958 · 1 Parent(s): 1a3633a

Add the Dia TTS model. Since Dia is not yet released to PyPI, we pull in the source directly.

Files changed (6)
  1. dia/__init__.py +6 -0
  2. dia/audio.py +185 -0
  3. dia/config.py +187 -0
  4. dia/layers.py +624 -0
  5. dia/model.py +455 -0
  6. dia/state.py +207 -0
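
Since the package is vendored rather than installed from PyPI, it is imported straight from the repo. A minimal usage sketch based on the code added below (the repo id is the default of `Dia.from_pretrained`; the prompt text and output path are illustrative assumptions):

    from dia.model import Dia

    # Loads config + weights from the Hub plus the DAC codec; picks CUDA/MPS/CPU automatically.
    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")

    # [S1]/[S2] speaker tags are mapped to byte tokens 0x01/0x02 by _prepare_text_input.
    audio = model.generate("[S1] Hello there. [S2] Hi, how are you?", verbose=True)

    model.save_audio("output.wav", audio)  # written at 44.1 kHz (DEFAULT_SAMPLE_RATE)
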
dia/__init__.py ADDED
@@ -0,0 +1,6 @@
+ from .model import Dia
+
+
+ __all__ = [
+     "Dia",
+ ]
dia/audio.py ADDED
@@ -0,0 +1,185 @@
+ import typing as tp
+
+ import torch
+
+
+ def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
+     Negative t_idx => BOS; t_idx >= T => PAD.
+     """
+     delay_arr = torch.tensor(delay_pattern, dtype=torch.int32)
+
+     t_idx_BxT = torch.broadcast_to(
+         torch.arange(T, dtype=torch.int32)[None, :],
+         [B, T],
+     )
+     t_idx_BxTx1 = t_idx_BxT[..., None]
+     t_idx_BxTxC = t_idx_BxTx1 - delay_arr.view(1, 1, C)
+
+     b_idx_BxTxC = torch.broadcast_to(
+         torch.arange(B, dtype=torch.int32).view(B, 1, 1),
+         [B, T, C],
+     )
+     c_idx_BxTxC = torch.broadcast_to(
+         torch.arange(C, dtype=torch.int32).view(1, 1, C),
+         [B, T, C],
+     )
+
+     # We must clamp time indices to [0..T-1] so gather_nd equivalent won't fail
+     t_clamped_BxTxC = torch.clamp(t_idx_BxTxC, 0, T - 1)
+
+     indices_BTCx3 = torch.stack(
+         [
+             b_idx_BxTxC.reshape(-1),
+             t_clamped_BxTxC.reshape(-1),
+             c_idx_BxTxC.reshape(-1),
+         ],
+         dim=1,
+     ).long()  # Ensure indices are long type for indexing
+
+     return t_idx_BxTxC, indices_BTCx3
+
+
+ def apply_audio_delay(
+     audio_BxTxC: torch.Tensor,
+     pad_value: int,
+     bos_value: int,
+     precomp: tp.Tuple[torch.Tensor, torch.Tensor],
+ ) -> torch.Tensor:
+     """
+     Applies the delay pattern to batched audio tokens using precomputed indices,
+     inserting BOS where t_idx < 0 and PAD where t_idx >= T.
+
+     Args:
+         audio_BxTxC: [B, T, C] int16 audio tokens (or int32/float)
+         pad_value: the padding token
+         bos_value: the BOS token
+         precomp: (t_idx_BxTxC, indices_BTCx3) from build_delay_indices
+
+     Returns:
+         result_BxTxC: [B, T, C] delayed audio tokens
+     """
+     device = audio_BxTxC.device  # Get device from input tensor
+     t_idx_BxTxC, indices_BTCx3 = precomp
+     t_idx_BxTxC = t_idx_BxTxC.to(device)  # Move precomputed indices to device
+     indices_BTCx3 = indices_BTCx3.to(device)
+
+     # Equivalent of tf.gather_nd using advanced indexing
+     # Ensure indices are long type if not already (build_delay_indices should handle this)
+     gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
+     gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
+
+     # Create masks on the correct device
+     mask_bos = t_idx_BxTxC < 0  # => place bos_value
+     mask_pad = t_idx_BxTxC >= audio_BxTxC.shape[1]  # => place pad_value
+
+     # Create scalar tensors on the correct device
+     bos_tensor = torch.tensor(bos_value, dtype=audio_BxTxC.dtype, device=device)
+     pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
+
+     # If mask_bos, BOS; else if mask_pad, PAD; else original gather
+     # All tensors should now be on the same device
+     result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))
+
+     return result_BxTxC
+
+
+ def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Precompute indices for the revert operation using PyTorch.
+
+     Returns:
+         A tuple (t_idx_BxTxC, indices_BTCx3) where:
+             - t_idx_BxTxC is a tensor of shape [B, T, C] computed as time indices plus the delay.
+             - indices_BTCx3 is a tensor of shape [B*T*C, 3] used for gathering, computed from:
+               batch indices, clamped time indices, and channel indices.
+     """
+     # Use default device unless specified otherwise; assumes inputs might define device later
+     device = None  # Or determine dynamically if needed, e.g., from a model parameter
+
+     delay_arr = torch.tensor(delay_pattern, dtype=torch.int32, device=device)
+
+     t_idx_BT1 = torch.broadcast_to(torch.arange(T, device=device).unsqueeze(0), [B, T])
+     t_idx_BT1 = t_idx_BT1.unsqueeze(-1)
+
+     t_idx_BxTxC = torch.minimum(
+         t_idx_BT1 + delay_arr.view(1, 1, C),
+         torch.tensor(T - 1, device=device),
+     )
+     b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
+     c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])
+
+     indices_BTCx3 = torch.stack(
+         [
+             b_idx_BxTxC.reshape(-1),
+             t_idx_BxTxC.reshape(-1),
+             c_idx_BxTxC.reshape(-1),
+         ],
+         axis=1,
+     ).long()  # Ensure indices are long type
+
+     return t_idx_BxTxC, indices_BTCx3
+
+
+ def revert_audio_delay(
+     audio_BxTxC: torch.Tensor,
+     pad_value: int,
+     precomp: tp.Tuple[torch.Tensor, torch.Tensor],
+     T: int,
+ ) -> torch.Tensor:
+     """
+     Reverts a delay pattern from batched audio tokens using precomputed indices (PyTorch version).
+
+     Args:
+         audio_BxTxC: Input delayed audio tensor
+         pad_value: Padding value for out-of-bounds indices
+         precomp: Precomputed revert indices tuple containing:
+             - t_idx_BxTxC: Time offset indices tensor
+             - indices_BTCx3: Gather indices tensor for original audio
+         T: Original sequence length before padding
+
+     Returns:
+         Reverted audio tensor with same shape as input
+     """
+     t_idx_BxTxC, indices_BTCx3 = precomp
+     device = audio_BxTxC.device  # Get device from input tensor
+
+     # Move precomputed indices to the same device as audio_BxTxC if they aren't already
+     t_idx_BxTxC = t_idx_BxTxC.to(device)
+     indices_BTCx3 = indices_BTCx3.to(device)
+
+     # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
+     gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
+     gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())  # Use .size() for robust reshaping
+
+     # Create pad_tensor on the correct device
+     pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
+     # Create T tensor on the correct device for comparison
+     T_tensor = torch.tensor(T, device=device)
+
+     result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)  # Changed np.where to torch.where
+
+     return result_BxTxC
+
+
+ @torch.no_grad()
+ @torch.inference_mode()
+ def decode(
+     model,
+     audio_codes,
+ ):
+     """
+     Decodes the given frames into an output audio waveform
+     """
+     if len(audio_codes) != 1:
+         raise ValueError(f"Expected one frame, got {len(audio_codes)}")
+
+     try:
+         audio_values = model.quantizer.from_codes(audio_codes)
+         audio_values = model.decode(audio_values[0])
+
+         return audio_values
+     except Exception as e:
+         print(f"Error in decode method: {str(e)}")
+         raise
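
A small round-trip sketch of the delay helpers above; the sizes and delay pattern are chosen for illustration, and the pad/BOS values (1025/1026) match the `DataConfig` defaults:

    import torch

    from dia.audio import apply_audio_delay, build_delay_indices, build_revert_indices, revert_audio_delay

    B, T, C = 1, 6, 3
    delay = [0, 1, 2]
    codes = torch.arange(B * T * C, dtype=torch.int32).reshape(B, T, C)

    # delayed[:, t, c] takes codes[:, t - delay[c], c]; BOS (1026) fills positions where t - delay[c] < 0.
    delayed = apply_audio_delay(
        codes, pad_value=1025, bos_value=1026,
        precomp=build_delay_indices(B, T, C, delay),
    )

    # Undo the shift so reverted[:, t, c] == codes[:, t, c] wherever t + delay[c] < T.
    reverted = revert_audio_delay(
        delayed, pad_value=1025,
        precomp=build_revert_indices(B, T, C, delay), T=T,
    )
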
dia/config.py ADDED
@@ -0,0 +1,187 @@
+ """Configuration management module for the Dia model.
+
+ This module provides comprehensive configuration management for the Dia model,
+ utilizing Pydantic for validation. It defines configurations for data processing,
+ model architecture (encoder and decoder), and training settings.
+
+ Key components:
+     - DataConfig: Parameters for data loading and preprocessing.
+     - EncoderConfig: Architecture details for the encoder module.
+     - DecoderConfig: Architecture details for the decoder module.
+     - ModelConfig: Combined model architecture settings.
+     - TrainingConfig: Training hyperparameters and settings.
+     - DiaConfig: Master configuration combining all components.
+ """
+
+ import os
+ from typing import Annotated
+
+ from pydantic import BaseModel, BeforeValidator, Field
+
+
+ class DataConfig(BaseModel, frozen=True):
+     """Configuration for data loading and preprocessing.
+
+     Attributes:
+         text_length: Maximum length of text sequences (must be multiple of 128).
+         audio_length: Maximum length of audio sequences (must be multiple of 128).
+         channels: Number of audio channels.
+         text_pad_value: Value used for padding text sequences.
+         audio_eos_value: Value representing the end of audio sequences.
+         audio_bos_value: Value representing the beginning of audio sequences.
+         audio_pad_value: Value used for padding audio sequences.
+         delay_pattern: List of delay values for each audio channel.
+     """
+
+     text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
+     audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
+     channels: int = Field(default=9, gt=0, multiple_of=1)
+     text_pad_value: int = Field(default=0)
+     audio_eos_value: int = Field(default=1024)
+     audio_pad_value: int = Field(default=1025)
+     audio_bos_value: int = Field(default=1026)
+     delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])
+
+     def __hash__(self) -> int:
+         """Generate a hash based on all fields of the config."""
+         return hash(
+             (
+                 self.text_length,
+                 self.audio_length,
+                 self.channels,
+                 self.text_pad_value,
+                 self.audio_pad_value,
+                 self.audio_bos_value,
+                 self.audio_eos_value,
+                 tuple(self.delay_pattern),
+             )
+         )
+
+
+ class EncoderConfig(BaseModel, frozen=True):
+     """Configuration for the encoder component of the Dia model.
+
+     Attributes:
+         n_layer: Number of transformer layers.
+         n_embd: Embedding dimension.
+         n_hidden: Hidden dimension size in the MLP layers.
+         n_head: Number of attention heads.
+         head_dim: Dimension per attention head.
+     """
+
+     n_layer: int = Field(gt=0)
+     n_embd: int = Field(gt=0)
+     n_hidden: int = Field(gt=0)
+     n_head: int = Field(gt=0)
+     head_dim: int = Field(gt=0)
+
+
+ class DecoderConfig(BaseModel, frozen=True):
+     """Configuration for the decoder component of the Dia model.
+
+     Attributes:
+         n_layer: Number of transformer layers.
+         n_embd: Embedding dimension.
+         n_hidden: Hidden dimension size in the MLP layers.
+         gqa_query_heads: Number of query heads for grouped-query self-attention.
+         kv_heads: Number of key/value heads for grouped-query self-attention.
+         gqa_head_dim: Dimension per query head for grouped-query self-attention.
+         cross_query_heads: Number of query heads for cross-attention.
+         cross_head_dim: Dimension per cross-attention head.
+     """
+
+     n_layer: int = Field(gt=0)
+     n_embd: int = Field(gt=0)
+     n_hidden: int = Field(gt=0)
+     gqa_query_heads: int = Field(gt=0)
+     kv_heads: int = Field(gt=0)
+     gqa_head_dim: int = Field(gt=0)
+     cross_query_heads: int = Field(gt=0)
+     cross_head_dim: int = Field(gt=0)
+
+
+ class ModelConfig(BaseModel, frozen=True):
+     """Main configuration container for the Dia model architecture.
+
+     Attributes:
+         encoder: Configuration for the encoder component.
+         decoder: Configuration for the decoder component.
+         src_vocab_size: Size of the source (text) vocabulary.
+         tgt_vocab_size: Size of the target (audio code) vocabulary.
+         dropout: Dropout probability applied within the model.
+         normalization_layer_epsilon: Epsilon value for normalization layers (e.g., LayerNorm).
+         weight_dtype: Data type for model weights (e.g., "float32", "bfloat16").
+         rope_min_timescale: Minimum timescale for Rotary Positional Embeddings (RoPE).
+         rope_max_timescale: Maximum timescale for Rotary Positional Embeddings (RoPE).
+     """
+
+     encoder: EncoderConfig
+     decoder: DecoderConfig
+     src_vocab_size: int = Field(default=128, gt=0)
+     tgt_vocab_size: int = Field(default=1028, gt=0)
+     dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
+     normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
+     weight_dtype: str = Field(default="float32", description="Weight precision")
+     rope_min_timescale: int = Field(default=1, description="Timescale For global Attention")
+     rope_max_timescale: int = Field(default=10_000, description="Timescale For global Attention")
+
+
+ class TrainingConfig(BaseModel, frozen=True):
+     pass
+
+
+ class DiaConfig(BaseModel, frozen=True):
+     """Master configuration for the Dia model.
+
+     Combines all sub-configurations into a single validated object.
+
+     Attributes:
+         version: Configuration version string.
+         model: Model architecture configuration.
+         training: Training process configuration (precision settings).
+         data: Data loading and processing configuration.
+     """
+
+     version: str = Field(default="1.0")
+     model: ModelConfig
+     # TODO: remove training. this is just for backward compatibility
+     training: TrainingConfig | None = Field(default=None)
+     data: DataConfig
+
+     def save(self, path: str) -> None:
+         """Save the current configuration instance to a JSON file.
+
+         Ensures the parent directory exists and the file has a .json extension.
+
+         Args:
+             path: The target file path to save the configuration.
+
+         Raises:
+             ValueError: If the path is not a file with a .json extension.
+         """
+         os.makedirs(os.path.dirname(path), exist_ok=True)
+         config_json = self.model_dump_json(indent=2)
+         with open(path, "w") as f:
+             f.write(config_json)
+
+     @classmethod
+     def load(cls, path: str) -> "DiaConfig | None":
+         """Load and validate a Dia configuration from a JSON file.
+
+         Args:
+             path: The path to the configuration file.
+
+         Returns:
+             A validated DiaConfig instance if the file exists and is valid,
+             otherwise None if the file is not found.
+
+         Raises:
+             ValueError: If the path does not point to an existing .json file.
+             pydantic.ValidationError: If the JSON content fails validation against the DiaConfig schema.
+         """
+         try:
+             with open(path, "r") as f:
+                 content = f.read()
+             return cls.model_validate_json(content)
+         except FileNotFoundError:
+             return None
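
For reference, a `DiaConfig` can also be constructed directly in code. The dimensions below are small placeholder values for illustration, not the released checkpoint's hyperparameters:

    from dia.config import DataConfig, DecoderConfig, DiaConfig, EncoderConfig, ModelConfig

    config = DiaConfig(
        model=ModelConfig(
            encoder=EncoderConfig(n_layer=2, n_embd=256, n_hidden=1024, n_head=4, head_dim=64),
            decoder=DecoderConfig(
                n_layer=2, n_embd=512, n_hidden=2048,
                gqa_query_heads=8, kv_heads=2, gqa_head_dim=64,
                cross_query_heads=8, cross_head_dim=64,
            ),
        ),
        # The BeforeValidator rounds lengths up to the next multiple of 128 (1000 -> 1024, 3000 -> 3072).
        data=DataConfig(text_length=1000, audio_length=3000),
    )

    config.save("configs/dia.json")              # save() creates the parent directory
    loaded = DiaConfig.load("configs/dia.json")  # returns None if the file does not exist
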
dia/layers.py ADDED
@@ -0,0 +1,624 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+ from torch import Tensor
+ from torch.nn import RMSNorm
+
+ from .config import DiaConfig
+ from .state import DecoderInferenceState, EncoderInferenceState, KVCache
+
+
+ def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
+     return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
+
+
+ class DenseGeneral(nn.Module):
+     """
+     PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
+
+     Stores weights (`kernel`) in the same layout as Jax and uses torch.tensordot
+     for the generalized matrix multiplication. Weight/bias shapes are calculated
+     and parameters created during initialization based on config.
+     `load_weights` validates shapes and copies data.
+
+     Attributes:
+         axis (Tuple[int, ...]): Input axis or axes to contract.
+         in_shapes (Tuple[int, ...]): Sizes of the input dimensions specified by `axis`.
+         out_features (Tuple[int, ...]): Shape of the output features (non-contracted dims).
+         use_bias (bool): Whether to add a bias term.
+         weight (nn.Parameter): The kernel parameter.
+         bias (Optional[nn.Parameter]): The bias parameter (if use_bias=True).
+     """
+
+     def __init__(
+         self,
+         in_shapes: tuple[int, ...],
+         out_features: tuple[int, ...],
+         axis: tuple[int, ...] = (-1,),
+         weight_dtype: torch.dtype | None = None,
+         device: torch.device | None = None,
+     ):
+         super().__init__()
+         self.in_shapes = in_shapes
+         self.out_features = out_features
+         self.axis = axis
+         self.kernel_shape = self.in_shapes + self.out_features
+
+         factory_kwargs = {"device": device, "dtype": weight_dtype}
+         self.weight = nn.Parameter(torch.empty(self.kernel_shape, **factory_kwargs))
+
+     def forward(self, inputs: Tensor) -> Tensor:
+         norm_axis = _normalize_axes(self.axis, inputs.ndim)
+         kernel_contract_axes = tuple(range(len(norm_axis)))
+
+         output = torch.tensordot(
+             inputs.to(self.weight.dtype),
+             self.weight,
+             dims=(norm_axis, kernel_contract_axes),
+         ).to(inputs.dtype)
+         return output
+
+
+ class MlpBlock(nn.Module):
+     """MLP block using DenseGeneral."""
+
+     def __init__(self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype):
+         super().__init__()
+         self.dtype = compute_dtype
+
+         self.wi_fused = DenseGeneral(
+             in_shapes=(embed_dim,),
+             out_features=(2, intermediate_dim),
+             axis=(-1,),
+             weight_dtype=compute_dtype,
+         )
+
+         self.wo = DenseGeneral(
+             in_shapes=(intermediate_dim,),
+             out_features=(embed_dim,),
+             axis=(-1,),
+             weight_dtype=compute_dtype,
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         """Forward pass."""
+         fused_x = self.wi_fused(x)
+
+         gate = fused_x[..., 0, :]
+         up = fused_x[..., 1, :]
+
+         hidden = torch.mul(F.silu(gate), up).to(self.dtype)
+
+         output = self.wo(hidden)
+         return output
+
+
+ class RotaryEmbedding(nn.Module):
+     """Rotary Position Embedding (RoPE) implementation in PyTorch."""
+
+     def __init__(
+         self,
+         embedding_dims: int,
+         min_timescale: int = 1,
+         max_timescale: int = 10000,
+         dtype: torch.dtype = torch.float32,
+     ):
+         super().__init__()
+         if embedding_dims % 2 != 0:
+             raise ValueError("Embedding dim must be even for RoPE.")
+         self.embedding_dims = embedding_dims
+         self.min_timescale = min_timescale
+         self.max_timescale = max_timescale
+         self.compute_dtype = dtype
+
+         half_embedding_dim = embedding_dims // 2
+         fraction = (2.0 * torch.arange(0, half_embedding_dim)) / embedding_dims
+         timescale = (self.min_timescale * (self.max_timescale / self.min_timescale) ** fraction).to(torch.float32)
+         self.register_buffer("timescale", timescale, persistent=False)
+
+     def forward(self, inputs: torch.Tensor, position: torch.Tensor):
+         """Applies RoPE."""
+         position = position.unsqueeze(-1).unsqueeze(-1)
+         sinusoid_inp = position / self.timescale
+         sin = torch.sin(sinusoid_inp)
+         cos = torch.cos(sinusoid_inp)
+         first_half, second_half = torch.chunk(inputs.to(torch.float32), 2, dim=-1)
+         first_part = first_half * cos - second_half * sin
+         second_part = second_half * cos + first_half * sin
+         return torch.cat((first_part.to(self.compute_dtype), second_part.to(self.compute_dtype)), dim=-1)
+
+
+ class Attention(nn.Module):
+     """Attention using DenseGeneral."""
+
+     def __init__(
+         self,
+         config: DiaConfig,
+         q_embed_dim: int,
+         kv_embed_dim: int,
+         num_query_heads: int,
+         num_kv_heads: int,
+         head_dim: int,
+         compute_dtype: torch.dtype,
+         is_cross_attn: bool = False,
+         out_embed_dim: int | None = None,
+     ):
+         super().__init__()
+         self.num_query_heads = num_query_heads
+         self.num_kv_heads = num_kv_heads
+         self.head_dim = head_dim
+         self.is_cross_attn = is_cross_attn
+         self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
+         self.projected_query_dim = num_query_heads * head_dim
+         if num_query_heads % num_kv_heads != 0:
+             raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
+         self.num_gqa_groups = num_query_heads // num_kv_heads
+
+         # --- Projection Layers using DenseGeneral ---
+         self.q_proj = DenseGeneral(
+             in_shapes=(q_embed_dim,),
+             out_features=(num_query_heads, head_dim),
+             axis=(-1,),
+             weight_dtype=compute_dtype,
+         )
+         self.k_proj = DenseGeneral(
+             in_shapes=(kv_embed_dim,),
+             out_features=(num_kv_heads, head_dim),
+             axis=(-1,),
+             weight_dtype=compute_dtype,
+         )
+         self.v_proj = DenseGeneral(
+             in_shapes=(kv_embed_dim,),
+             out_features=(num_kv_heads, head_dim),
+             axis=(-1,),
+             weight_dtype=compute_dtype,
+         )
+         self.o_proj = DenseGeneral(
+             in_shapes=(num_query_heads, head_dim),
+             out_features=(self.output_dim,),
+             axis=(-2, -1),
+             weight_dtype=compute_dtype,
+         )
+
+         # --- Rotary Embedding ---
+         self.rotary_emb = RotaryEmbedding(
+             embedding_dims=self.head_dim,
+             min_timescale=config.model.rope_min_timescale,
+             max_timescale=config.model.rope_max_timescale,
+             dtype=compute_dtype,
+         )
+
+     def forward(
+         self,
+         Xq: torch.Tensor,  # (B, T, D) T = 1 in AR generation
+         Xkv: torch.Tensor,  # (B, S, E) S = 1 in AR generation
+         q_positions: torch.Tensor,  # (B, T)
+         kv_positions: torch.Tensor | None = None,  # (B, S)
+         attn_mask: torch.Tensor | None = None,  # None in Decoder Self Attention, Valid mask in Others
+         cache: KVCache | None = None,  # None in Encoder, KVCache in Decoder
+         prefill: bool = False,
+         is_causal: bool = False,
+     ) -> torch.Tensor:
+         """
+         Performs attention calculation with optional KV caching.
+
+         Args:
+             Xq: Query tensor (B, T, D). T=1 during single-step decoding.
+             Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
+             q_positions: Positions for queries (B, T).
+             kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
+             attn_mask: Attention mask.
+             cache: KVCache. For self-attention it is updated in place (full prefill or a
+                 single-step append); for cross-attention it is read as precomputed K/V.
+             prefill: If True, use prefill mode.
+             is_causal: If True, rely on scaled_dot_product_attention's built-in causal masking.
+
+         Returns:
+             output: The attention output tensor (B, T, output_dim).
+         """
+         if kv_positions is None:
+             kv_positions = q_positions
+         original_dtype = Xq.dtype
+
+         Xq_BxTxNxH = self.q_proj(Xq)
+         Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
+         Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
+
+         attn_k: torch.Tensor | None = None
+         attn_v: torch.Tensor | None = None
+
+         if self.is_cross_attn:
+             attn_k, attn_v = cache.k, cache.v
+         else:
+             Xk_BxSxKxH = self.k_proj(Xkv)  # (B, S, K, H)
+             Xv_BxSxKxH = self.v_proj(Xkv)  # (B, S, K, H)
+             Xk_BxSxKxH = self.rotary_emb(Xk_BxSxKxH, position=kv_positions)  # (B, S, K, H)
+
+             Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
+             Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2)  # (B, K, S, H)
+
+             if cache is None:
+                 attn_k = Xk_BxKxSxH
+                 attn_v = Xv_BxKxSxH
+             else:
+                 if prefill:
+                     attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
+                     cache.prefill(attn_k, attn_v)
+                 else:
+                     attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)
+
+         attn_output = F.scaled_dot_product_attention(
+             Xq_BxNxTxH,
+             attn_k,
+             attn_v,
+             attn_mask=attn_mask,
+             scale=1.0,
+             enable_gqa=self.num_gqa_groups > 1,
+             is_causal=is_causal,
+         )
+
+         attn_output = attn_output.transpose(1, 2).contiguous()  # (B, T, N, H)
+         output = self.o_proj(attn_output)
+
+         return output.to(original_dtype)
+
+
+ class EncoderLayer(nn.Module):
+     """Transformer Encoder Layer using DenseGeneral."""
+
+     def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
+         super().__init__()
+         self.config = config
+         model_config = config.model
+         enc_config = config.model.encoder
+         embed_dim = enc_config.n_embd
+         self.compute_dtype = compute_dtype
+
+         self.pre_sa_norm = RMSNorm(
+             embed_dim,
+             eps=model_config.normalization_layer_epsilon,
+             dtype=torch.float32,
+         )
+         self.self_attention = Attention(
+             config,
+             q_embed_dim=embed_dim,
+             kv_embed_dim=embed_dim,
+             num_query_heads=enc_config.n_head,
+             num_kv_heads=enc_config.n_head,
+             head_dim=enc_config.head_dim,
+             compute_dtype=compute_dtype,
+             is_cross_attn=False,
+             out_embed_dim=embed_dim,
+         )
+         self.post_sa_norm = RMSNorm(
+             embed_dim,
+             eps=model_config.normalization_layer_epsilon,
+             dtype=torch.float32,
+         )
+         self.mlp = MlpBlock(embed_dim=embed_dim, intermediate_dim=enc_config.n_hidden, compute_dtype=compute_dtype)
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         state: EncoderInferenceState,
+     ) -> torch.Tensor:
+         residual = x
+         x_norm = self.pre_sa_norm(x).to(self.compute_dtype)
+
+         sa_out = self.self_attention(
+             Xq=x_norm,
+             Xkv=x_norm,
+             q_positions=state.positions,
+             kv_positions=state.positions,
+             attn_mask=state.attn_mask,
+         )
+         x = residual + sa_out
+
+         residual = x
+         x_norm = self.post_sa_norm(x).to(self.compute_dtype)
+         mlp_out = self.mlp(x_norm)
+         x = residual + mlp_out
+
+         return x
+
+
+ class Encoder(nn.Module):
+     """Transformer Encoder Stack using DenseGeneral."""
+
+     def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
+         super().__init__()
+         self.config = config
+         model_config = config.model
+         enc_config = config.model.encoder
+         self.compute_dtype = compute_dtype
+
+         self.embedding = nn.Embedding(
+             model_config.src_vocab_size,
+             enc_config.n_embd,
+             dtype=compute_dtype,
+         )
+         self.layers = nn.ModuleList([EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)])
+         self.norm = RMSNorm(
+             enc_config.n_embd,
+             eps=model_config.normalization_layer_epsilon,
+             dtype=torch.float32,
+         )
+
+     def forward(
+         self,
+         x_ids: torch.Tensor,
+         state: EncoderInferenceState,
+     ) -> torch.Tensor:
+         x = self.embedding(x_ids)
+
+         for layer in self.layers:
+             x = layer(x, state)
+
+         x = self.norm(x).to(self.compute_dtype)
+         return x
+
+
+ class DecoderLayer(nn.Module):
+     """Transformer Decoder Layer using DenseGeneral."""
+
+     def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
+         super().__init__()
+         self.config = config
+         model_config = config.model
+         dec_config = config.model.decoder
+         enc_config = config.model.encoder
+         dec_embed_dim = dec_config.n_embd
+         enc_embed_dim = enc_config.n_embd
+         self.compute_dtype = compute_dtype
+
+         # Norms
+         self.pre_sa_norm = RMSNorm(
+             dec_embed_dim,
+             eps=model_config.normalization_layer_epsilon,
+             dtype=torch.float32,
+         )
+         self.pre_ca_norm = RMSNorm(
+             dec_embed_dim,
+             eps=model_config.normalization_layer_epsilon,
+             dtype=torch.float32,
+         )
+         self.pre_mlp_norm = RMSNorm(
+             dec_embed_dim,
+             eps=model_config.normalization_layer_epsilon,
+             dtype=torch.float32,
+         )
+
+         # Self-Attention (GQA) with Causal Masking
+         self.self_attention = Attention(
+             config,
+             q_embed_dim=dec_embed_dim,
+             kv_embed_dim=dec_embed_dim,
+             num_query_heads=dec_config.gqa_query_heads,
+             num_kv_heads=dec_config.kv_heads,
+             head_dim=dec_config.gqa_head_dim,
+             compute_dtype=compute_dtype,
+             is_cross_attn=False,
+             out_embed_dim=dec_embed_dim,
+         )
+         # Cross-Attention (MHA)
+         self.cross_attention = Attention(
+             config=config,
+             q_embed_dim=dec_embed_dim,
+             kv_embed_dim=enc_embed_dim,  # Note kv_embed_dim
+             num_query_heads=dec_config.cross_query_heads,
+             num_kv_heads=dec_config.cross_query_heads,
+             head_dim=dec_config.cross_head_dim,
+             compute_dtype=compute_dtype,
+             is_cross_attn=True,
+             out_embed_dim=dec_embed_dim,
+         )
+         # MLP
+         self.mlp = MlpBlock(
+             embed_dim=dec_embed_dim,
+             intermediate_dim=dec_config.n_hidden,
+             compute_dtype=compute_dtype,
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         state: DecoderInferenceState,
+         self_attn_cache: KVCache | None = None,
+         cross_attn_cache: KVCache | None = None,
+         prefill: bool = False,
+     ) -> torch.Tensor:
+         residual = x
+         x_norm = self.pre_sa_norm(x).to(self.compute_dtype)
+
+         sa_out = self.self_attention(
+             Xq=x_norm,  # (2, 1, D)
+             Xkv=x_norm,  # (2, 1, D)
+             q_positions=state.dec_positions,  # (2, 1)
+             kv_positions=state.dec_positions,  # (2, 1)
+             attn_mask=None,
+             cache=self_attn_cache,
+             prefill=prefill,
+             is_causal=prefill,
+         )
+
+         x = residual + sa_out
+
+         residual = x
+         x_norm = self.pre_ca_norm(x).to(self.compute_dtype)
+         ca_out = self.cross_attention(
+             Xq=x_norm,
+             Xkv=state.enc_out,
+             q_positions=state.dec_positions,
+             kv_positions=state.enc_positions,
+             attn_mask=state.dec_cross_attn_mask,
+             cache=cross_attn_cache,
+         )
+         x = residual + ca_out
+
+         residual = x
+         x_norm = self.pre_mlp_norm(x).to(self.compute_dtype)
+         mlp_out = self.mlp(x_norm)
+         x = residual + mlp_out
+
+         return x
+
+
+ class Decoder(nn.Module):
+     """Transformer Decoder Stack using DenseGeneral."""
+
+     def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
+         super().__init__()
+         self.config = config
+         model_config = config.model
+         dec_config = config.model.decoder
+         data_config = config.data
+         self.num_channels = data_config.channels
+         self.num_layers = dec_config.n_layer
+
+         self.embeddings = nn.ModuleList(
+             [
+                 nn.Embedding(model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype)
+                 for _ in range(self.num_channels)
+             ]
+         )
+         self.layers = nn.ModuleList(
+             [DecoderLayer(config=config, compute_dtype=compute_dtype) for _ in range(self.num_layers)]
+         )
+
+         self.norm = RMSNorm(
+             dec_config.n_embd,
+             eps=model_config.normalization_layer_epsilon,
+             dtype=torch.float32,
+         )
+
+         self.logits_dense = DenseGeneral(
+             in_shapes=(dec_config.n_embd,),
+             out_features=(self.num_channels, model_config.tgt_vocab_size),
+             axis=(-1,),
+             weight_dtype=compute_dtype,
+         )
+
+     def precompute_cross_attn_cache(
+         self,
+         enc_out: torch.Tensor,  # (B, S, E)
+         enc_positions: torch.Tensor,  # (B, S)
+     ) -> list[KVCache]:
+         """
+         Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
+         """
+         per_layer_kv_cache: list[KVCache] = []
+
+         for layer in self.layers:
+             cross_attn_module = layer.cross_attention
+             k_proj = cross_attn_module.k_proj(enc_out)
+             v_proj = cross_attn_module.v_proj(enc_out)
+
+             k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
+             k = k_proj.transpose(1, 2)
+             v = v_proj.transpose(1, 2)
+
+             per_layer_kv_cache.append(KVCache.from_kv(k, v))
+
+         return per_layer_kv_cache
+
+     def decode_step(
+         self,
+         tgt_ids_Bx1xC: torch.Tensor,  # [B, 1, C]
+         state: DecoderInferenceState,
+     ) -> torch.Tensor:
+         """
+         Performs a single decoding step, managing KV caches layer by layer.
+
+         Returns:
+             logits_Bx1xCxV: The output logits for the current step (B, 1, C, V), cast to float32.
+         """
+
+         x = None
+         for i in range(self.num_channels):
+             channel_tokens = tgt_ids_Bx1xC[..., i]
+             channel_embed = self.embeddings[i](channel_tokens)
+             x = channel_embed if x is None else x + channel_embed
+
+         for i, layer in enumerate(self.layers):
+             self_cache = state.self_attn_cache[i]
+             cross_cache = state.cross_attn_cache[i]
+             x = layer(
+                 x,  # (2, 1, D)
+                 state,
+                 self_attn_cache=self_cache,
+                 cross_attn_cache=cross_cache,
+             )
+
+         x = self.norm(x)
+         logits_Bx1xCxV = self.logits_dense(x)
+
+         return logits_Bx1xCxV.to(torch.float32)
+
+     def forward(self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState) -> torch.Tensor:
+         """
+         Forward pass for the Decoder stack (used for prefill), managing KV caches.
+
+         Args:
+             tgt_ids_BxTxC: Target token IDs (B, T, C).
+             state: Decoder inference state holding the encoder output, positions,
+                 the cross-attention mask, and the per-layer self/cross-attention KV caches.
+
+         Returns:
+             logits: The output logits (B, T, C, V), cast to float32.
+         """
+         _, _, num_channels_in = tgt_ids_BxTxC.shape
+         assert num_channels_in == self.num_channels, "Input channels mismatch"
+
+         # Embeddings
+         x = None
+         for i in range(self.num_channels):
+             channel_tokens = tgt_ids_BxTxC[..., i]
+             channel_embed = self.embeddings[i](channel_tokens)
+             x = channel_embed if x is None else x + channel_embed
+
+         for i, layer in enumerate(self.layers):
+             self_cache = state.self_attn_cache[i]
+             cross_cache = state.cross_attn_cache[i]
+             x = layer(x, state, self_attn_cache=self_cache, cross_attn_cache=cross_cache, prefill=True)
+
+         # Final Norm
+         x = self.norm(x)
+         logits_BxTxCxV = self.logits_dense(x)
+
+         return logits_BxTxCxV.to(torch.float32)
+
+
+ class DiaModel(
+     nn.Module,
+     PyTorchModelHubMixin,
+     repo_url="https://github.com/nari-labs/dia",
+     pipeline_tag="text-to-speech",
+     license="apache-2.0",
+     coders={
+         DiaConfig: (
+             lambda x: x.model_dump(),
+             lambda data: DiaConfig.model_validate(data),
+         ),
+     },
+ ):
+     """PyTorch Dia Model using DenseGeneral."""
+
+     def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
+         super().__init__()
+         self.config = config
+         self.encoder = Encoder(config, compute_dtype)
+         self.decoder = Decoder(config, compute_dtype)
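
`DenseGeneral` stores its kernel in the Flax layout and contracts with `torch.tensordot`; with a single input and output axis it reduces to a bias-free linear layer. A small sketch of that equivalence (the shapes are arbitrary, and the weights are filled in here only because the layer creates them uninitialized with `torch.empty`):

    import torch
    import torch.nn.functional as F

    from dia.layers import DenseGeneral

    x = torch.randn(2, 5, 16)

    # Multi-axis output: kernel shape (16, 4, 8) -> output shape (2, 5, 4, 8).
    dense = DenseGeneral(in_shapes=(16,), out_features=(4, 8), axis=(-1,))
    torch.nn.init.normal_(dense.weight)
    assert dense(x).shape == (2, 5, 4, 8)

    # Single output axis behaves like nn.Linear without bias (the kernel is the transposed weight).
    dense2 = DenseGeneral(in_shapes=(16,), out_features=(32,), axis=(-1,))
    torch.nn.init.normal_(dense2.weight)
    assert torch.allclose(dense2(x), F.linear(x, dense2.weight.T), atol=1e-5)
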
dia/model.py ADDED
@@ -0,0 +1,455 @@
+ import time
+ from enum import Enum
+
+ import dac
+ import numpy as np
+ import torch
+ import torchaudio
+
+ from .audio import apply_audio_delay, build_delay_indices, build_revert_indices, decode, revert_audio_delay
+ from .config import DiaConfig
+ from .layers import DiaModel
+ from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
+
+
+ DEFAULT_SAMPLE_RATE = 44100
+
+
+ def _get_default_device():
+     if torch.cuda.is_available():
+         return torch.device("cuda")
+     elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+         return torch.device("mps")
+     return torch.device("cpu")
+
+
+ def _sample_next_token(
+     logits_BCxV: torch.Tensor,
+     temperature: float,
+     top_p: float,
+     cfg_filter_top_k: int | None = None,
+ ) -> torch.Tensor:
+     if temperature == 0.0:
+         return torch.argmax(logits_BCxV, dim=-1)
+
+     logits_BCxV = logits_BCxV / temperature
+     if cfg_filter_top_k is not None:
+         _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
+         mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
+         mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
+         logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)
+
+     if top_p < 1.0:
+         probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
+         sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
+         cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
+
+         sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
+         sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[..., :-1].clone()
+         sorted_indices_to_remove_BCxV[..., 0] = 0
+
+         indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
+         indices_to_remove_BCxV.scatter_(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV)
+         logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
+
+     final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
+
+     sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
+     sampled_indices_C = sampled_indices_BC.squeeze(-1)
+     return sampled_indices_C
+
+
+ class ComputeDtype(str, Enum):
+     FLOAT32 = "float32"
+     FLOAT16 = "float16"
+     BFLOAT16 = "bfloat16"
+
+     def to_dtype(self) -> torch.dtype:
+         if self == ComputeDtype.FLOAT32:
+             return torch.float32
+         elif self == ComputeDtype.FLOAT16:
+             return torch.float16
+         elif self == ComputeDtype.BFLOAT16:
+             return torch.bfloat16
+         else:
+             raise ValueError(f"Unsupported compute dtype: {self}")
+
+
+ class Dia:
+     def __init__(
+         self,
+         config: DiaConfig,
+         compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+         device: torch.device | None = None,
+     ):
+         """Initializes the Dia model.
+
+         Args:
+             config: The configuration object for the model.
+             compute_dtype: The computation dtype to use.
+             device: The device to load the model onto. If None, will automatically select the best available device.
+
+         Raises:
+             RuntimeError: If there is an error loading the DAC model.
+         """
+         super().__init__()
+         self.config = config
+         self.device = device if device is not None else _get_default_device()
+         if isinstance(compute_dtype, str):
+             compute_dtype = ComputeDtype(compute_dtype)
+         self.compute_dtype = compute_dtype.to_dtype()
+         self.model = DiaModel(config, self.compute_dtype)
+         self.dac_model = None
+
+     @classmethod
+     def from_local(
+         cls,
+         config_path: str,
+         checkpoint_path: str,
+         compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+         device: torch.device | None = None,
+     ) -> "Dia":
+         """Loads the Dia model from local configuration and checkpoint files.
+
+         Args:
+             config_path: Path to the configuration JSON file.
+             checkpoint_path: Path to the model checkpoint (.pth) file.
+             compute_dtype: The computation dtype to use.
+             device: The device to load the model onto. If None, will automatically select the best available device.
+
+         Returns:
+             An instance of the Dia model loaded with weights and set to eval mode.
+
+         Raises:
+             FileNotFoundError: If the config or checkpoint file is not found.
+             RuntimeError: If there is an error loading the checkpoint.
+         """
+         config = DiaConfig.load(config_path)
+         if config is None:
+             raise FileNotFoundError(f"Config file not found at {config_path}")
+
+         dia = cls(config, compute_dtype, device)
+
+         try:
+             state_dict = torch.load(checkpoint_path, map_location=dia.device)
+             dia.model.load_state_dict(state_dict)
+         except FileNotFoundError:
+             raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
+         except Exception as e:
+             raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e
+
+         dia.model.to(dia.device)
+         dia.model.eval()
+         dia._load_dac_model()
+         return dia
+
+     @classmethod
+     def from_pretrained(
+         cls,
+         model_name: str = "nari-labs/Dia-1.6B",
+         compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
+         device: torch.device | None = None,
+     ) -> "Dia":
+         """Loads the Dia model from a Hugging Face Hub repository.
+
+         Downloads the configuration and checkpoint files from the specified
+         repository ID and then loads the model.
+
+         Args:
+             model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
+             compute_dtype: The computation dtype to use.
+             device: The device to load the model onto. If None, will automatically select the best available device.
+
+         Returns:
+             An instance of the Dia model loaded with weights and set to eval mode.
+
+         Raises:
+             FileNotFoundError: If config or checkpoint download/loading fails.
+             RuntimeError: If there is an error loading the checkpoint.
+         """
+         if isinstance(compute_dtype, str):
+             compute_dtype = ComputeDtype(compute_dtype)
+         loaded_model = DiaModel.from_pretrained(model_name, compute_dtype=compute_dtype.to_dtype())
+         config = loaded_model.config
+         dia = cls(config, compute_dtype, device)
+
+         dia.model = loaded_model
+         dia.model.to(dia.device)
+         dia.model.eval()
+         dia._load_dac_model()
+         return dia
+
+     def _load_dac_model(self):
+         try:
+             dac_model_path = dac.utils.download()
+             dac_model = dac.DAC.load(dac_model_path).to(self.device)
+         except Exception as e:
+             raise RuntimeError("Failed to load DAC model") from e
+         self.dac_model = dac_model
+
+     def _prepare_text_input(self, text: str) -> torch.Tensor:
+         """Encodes text prompt, pads, and creates attention mask and positions."""
+         text_pad_value = self.config.data.text_pad_value
+         max_len = self.config.data.text_length
+
+         byte_text = text.encode("utf-8")
+         replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
+         text_tokens = list(replaced_bytes)
+
+         current_len = len(text_tokens)
+         padding_needed = max_len - current_len
+         if padding_needed <= 0:
+             text_tokens = text_tokens[:max_len]
+             padded_text_np = np.array(text_tokens, dtype=np.uint8)
+         else:
+             padded_text_np = np.pad(
+                 text_tokens,
+                 (0, padding_needed),
+                 mode="constant",
+                 constant_values=text_pad_value,
+             ).astype(np.uint8)
+
+         src_tokens = torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)  # [1, S]
+         return src_tokens
+
+     def _prepare_audio_prompt(self, audio_prompt: torch.Tensor | None) -> tuple[torch.Tensor, int]:
+         num_channels = self.config.data.channels
+         audio_bos_value = self.config.data.audio_bos_value
+         audio_pad_value = self.config.data.audio_pad_value
+         delay_pattern = self.config.data.delay_pattern
+         max_delay_pattern = max(delay_pattern)
+
+         prefill = torch.full(
+             (1, num_channels),
+             fill_value=audio_bos_value,
+             dtype=torch.int,
+             device=self.device,
+         )
+
+         prefill_step = 1
+
+         if audio_prompt is not None:
+             prefill_step += audio_prompt.shape[0]
+             prefill = torch.cat([prefill, audio_prompt], dim=0)
+
+         delay_pad_tensor = torch.full(
+             (max_delay_pattern, num_channels), fill_value=-1, dtype=torch.int, device=self.device
+         )
+         prefill = torch.cat([prefill, delay_pad_tensor], dim=0)
+
+         delay_precomp = build_delay_indices(
+             B=1,
+             T=prefill.shape[0],
+             C=num_channels,
+             delay_pattern=delay_pattern,
+         )
+
+         prefill = apply_audio_delay(
+             audio_BxTxC=prefill.unsqueeze(0),
+             pad_value=audio_pad_value,
+             bos_value=audio_bos_value,
+             precomp=delay_precomp,
+         ).squeeze(0)
+
+         return prefill, prefill_step
+
+     def _prepare_generation(self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool):
+         enc_input_cond = self._prepare_text_input(text)
+         enc_input_uncond = torch.zeros_like(enc_input_cond)
+         enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)
+
+         if isinstance(audio_prompt, str):
+             audio_prompt = self.load_audio(audio_prompt)
+         prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)
+
+         if verbose:
+             print("generate: data loaded")
+
+         enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
+         encoder_out = self.model.encoder(enc_input, enc_state)
+
+         dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(encoder_out, enc_state.positions)
+         dec_state = DecoderInferenceState.new(
+             self.config, enc_state, encoder_out, dec_cross_attn_cache, self.compute_dtype
+         )
+         dec_output = DecoderOutput.new(self.config, self.device)
+         dec_output.prefill(prefill, prefill_step)
+
+         dec_step = prefill_step - 1
+         if dec_step > 0:
+             dec_state.prepare_step(0, dec_step)
+             tokens_BxTxC = dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
+             self.model.decoder.forward(tokens_BxTxC, dec_state)
+
+         return dec_state, dec_output
+
+     def _decoder_step(
+         self,
+         tokens_Bx1xC: torch.Tensor,
+         dec_state: DecoderInferenceState,
+         cfg_scale: float,
+         temperature: float,
+         top_p: float,
+         cfg_filter_top_k: int,
+     ) -> torch.Tensor:
+         audio_eos_value = self.config.data.audio_eos_value
+         logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)
+
+         logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
+         uncond_logits_CxV = logits_last_BxCxV[0, :, :]
+         cond_logits_CxV = logits_last_BxCxV[1, :, :]
+
+         logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
+         logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
+         logits_CxV[1:, audio_eos_value:] = -torch.inf
+
+         pred_C = _sample_next_token(
+             logits_CxV.float(),
+             temperature=temperature,
+             top_p=top_p,
+             cfg_filter_top_k=cfg_filter_top_k,
+         )
+         return pred_C
+
+     def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
+         num_channels = self.config.data.channels
+         seq_length = generated_codes.shape[0]
+         delay_pattern = self.config.data.delay_pattern
+         audio_pad_value = self.config.data.audio_pad_value
+         max_delay_pattern = max(delay_pattern)
+
+         revert_precomp = build_revert_indices(
+             B=1,
+             T=seq_length,
+             C=num_channels,
+             delay_pattern=delay_pattern,
+         )
+
+         codebook = revert_audio_delay(
+             audio_BxTxC=generated_codes.unsqueeze(0),
+             pad_value=audio_pad_value,
+             precomp=revert_precomp,
+             T=seq_length,
+         )[:, :-max_delay_pattern, :]
+
+         min_valid_index = 0
+         max_valid_index = 1023
+         invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
+         codebook[invalid_mask] = 0
+
+         audio = decode(self.dac_model, codebook.transpose(1, 2))
+
+         return audio.squeeze().cpu().numpy()
+
+     def load_audio(self, audio_path: str) -> torch.Tensor:
+         audio, sr = torchaudio.load(audio_path, channels_first=True)  # C, T
+         if sr != DEFAULT_SAMPLE_RATE:
+             audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
+         audio = audio.to(self.device).unsqueeze(0)  # 1, C, T
+         audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
+         _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)  # 1, C, T
+         return encoded_frame.squeeze(0).transpose(0, 1)
+
+     def save_audio(self, path: str, audio: np.ndarray):
+         import soundfile as sf
+
+         sf.write(path, audio, DEFAULT_SAMPLE_RATE)
+
+     @torch.inference_mode()
+     def generate(
+         self,
+         text: str,
+         max_tokens: int | None = None,
+         cfg_scale: float = 3.0,
+         temperature: float = 1.3,
+         top_p: float = 0.95,
+         use_torch_compile: bool = False,
+         cfg_filter_top_k: int = 35,
+         audio_prompt: str | torch.Tensor | None = None,
+         audio_prompt_path: str | None = None,
+         use_cfg_filter: bool | None = None,
+         verbose: bool = False,
+     ) -> np.ndarray | None:
+         audio_eos_value = self.config.data.audio_eos_value
+         audio_pad_value = self.config.data.audio_pad_value
+         delay_pattern = self.config.data.delay_pattern
+         max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
+         max_delay_pattern = max(delay_pattern)
+         self.model.eval()
+
+         if audio_prompt_path:
+             print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
+             audio_prompt = audio_prompt_path
+         if use_cfg_filter is not None:
+             print("Warning: use_cfg_filter is deprecated.")
+
+         if verbose:
+             total_start_time = time.time()
+
+         dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
+         dec_step = dec_output.prefill_step - 1
+
+         bos_countdown = max_delay_pattern
+         eos_detected = False
+         eos_countdown = -1
+
+         if use_torch_compile:
+             step_fn = torch.compile(self._decoder_step, mode="default")
+         else:
+             step_fn = self._decoder_step
+
+         if verbose:
+             print("generate: starting generation loop")
+             if use_torch_compile:
+                 print("generate: with use_torch_compile=True, the first step will be slow due to compilation")
+             start_time = time.time()
+
+         while dec_step < max_tokens:
+             dec_state.prepare_step(dec_step)
+             tokens_Bx1xC = dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
+             pred_C = step_fn(
+                 tokens_Bx1xC,
+                 dec_state,
+                 cfg_scale,
+                 temperature,
+                 top_p,
+                 cfg_filter_top_k,
+             )
+
+             if (not eos_detected and pred_C[0] == audio_eos_value) or dec_step == max_tokens - max_delay_pattern - 1:
+                 eos_detected = True
+                 eos_countdown = max_delay_pattern
+
+             if eos_countdown > 0:
+                 step_after_eos = max_delay_pattern - eos_countdown
+                 for i, d in enumerate(delay_pattern):
+                     if step_after_eos == d:
+                         pred_C[i] = audio_eos_value
+                     elif step_after_eos > d:
+                         pred_C[i] = audio_pad_value
+                 eos_countdown -= 1
+
+             bos_countdown = max(0, bos_countdown - 1)
+             dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)
+
+             if eos_countdown == 0:
+                 break
+
+             dec_step += 1
+             if verbose and dec_step % 86 == 0:
+                 duration = time.time() - start_time
+                 print(
+                     f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
+                 )
+                 start_time = time.time()
+
+         if dec_output.prefill_step >= dec_step + 1:
+             print("Warning: Nothing generated")
+             return None
+
+         generated_codes = dec_output.generated_tokens[dec_output.prefill_step : dec_step + 1, :]
+
+         if verbose:
+             total_step = dec_step + 1 - dec_output.prefill_step
+             total_duration = time.time() - total_start_time
+             print(f"generate: total step={total_step}, total duration={total_duration:.3f}s")
+
+         return self._generate_output(generated_codes)
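
`_sample_next_token` operates on per-channel logits of shape [C, V] after CFG mixing. A quick sketch with random logits; the function is module-private, so this is for illustration only (C=9 and V=1028 are the config defaults):

    import torch

    from dia.model import _sample_next_token

    C, V = 9, 1028
    logits_CxV = torch.randn(C, V)

    greedy = _sample_next_token(logits_CxV, temperature=0.0, top_p=1.0)  # pure argmax per channel
    sampled = _sample_next_token(logits_CxV, temperature=1.3, top_p=0.95, cfg_filter_top_k=35)

    assert greedy.shape == (C,) and sampled.shape == (C,)
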
dia/state.py ADDED
@@ -0,0 +1,207 @@
+ from dataclasses import dataclass
+
+ import torch
+
+ from .config import DiaConfig
+
+
+ def create_attn_mask(
+     q_padding_mask_1d: torch.Tensor,
+     k_padding_mask_1d: torch.Tensor,
+     device: torch.device,
+     is_causal: bool = False,
+ ) -> torch.Tensor:
+     """
+     Creates the attention mask (self or cross) mimicking JAX segment ID logic.
+     """
+     B1, Tq = q_padding_mask_1d.shape
+     B2, Tk = k_padding_mask_1d.shape
+     assert B1 == B2, "Query and key batch dimensions must match"
+
+     p_mask_q = q_padding_mask_1d.unsqueeze(2)  # Shape [B, Tq, 1]
+     p_mask_k = k_padding_mask_1d.unsqueeze(1)  # Shape [B, 1, Tk]
+
+     # Condition A: Non-padding query attends to non-padding key
+     non_pad_attends_non_pad = p_mask_q & p_mask_k  # Shape [B, Tq, Tk]
+
+     # Condition B: Padding query attends to padding key
+     pad_attends_pad = (~p_mask_q) & (~p_mask_k)  # Shape [B, Tq, Tk]
+
+     # Combine: True if padding status is compatible (both non-pad OR both pad)
+     mask = non_pad_attends_non_pad | pad_attends_pad  # Shape [B, Tq, Tk]
+
+     if is_causal:
+         assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
+         causal_mask_2d = torch.tril(torch.ones((Tq, Tk), dtype=torch.bool, device=device))  # Shape [Tq, Tk]
+         causal_mask = mask & causal_mask_2d  # Shape [B, Tq, Tk]
+         return causal_mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk]
+     else:
+         return mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk]
+
+
+ @dataclass
+ class EncoderInferenceState:
+     """Parameters specifically for encoder inference."""
+
+     max_seq_len: int
+     device: torch.device
+     positions: torch.Tensor
+     padding_mask: torch.Tensor
+     attn_mask: torch.Tensor
+
+     @classmethod
+     def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceState":
+         """Creates an EncoderInferenceState from a DiaConfig and the conditional source tokens."""
+         device = cond_src.device
+
+         positions = (
+             torch.arange(config.data.text_length, dtype=torch.float32, device=device).unsqueeze(0).expand(2, -1)
+         )
+         padding_mask = (cond_src != config.data.text_pad_value).to(device).expand(2, -1)
+         attn_mask = create_attn_mask(padding_mask, padding_mask, device, is_causal=False)
+
+         return cls(
+             max_seq_len=config.data.text_length,
+             device=device,
+             positions=positions,
+             padding_mask=padding_mask,
+             attn_mask=attn_mask,
+         )
+
+
+ class KVCache:
+     def __init__(
+         self,
+         num_heads: int,
+         max_len: int,
+         head_dim: int,
+         dtype: torch.dtype,
+         device: torch.device,
+         k: torch.Tensor | None = None,
+         v: torch.Tensor | None = None,
+     ):
+         self.k = torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device) if k is None else k
+         self.v = torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device) if v is None else v
+         self.current_idx = torch.tensor(0)
+
+     @classmethod
+     def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
+         return cls(
+             num_heads=k.shape[1],
+             max_len=k.shape[2],
+             head_dim=k.shape[3],
+             dtype=k.dtype,
+             device=k.device,
+             k=k,
+             v=v,
+         )
+
+     def update(self, k: torch.Tensor, v: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
+         self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
+         self.current_idx += 1
+         return self.k[:, :, : self.current_idx, :], self.v[:, :, : self.current_idx, :]
+
+     def prefill(self, k: torch.Tensor, v: torch.Tensor) -> None:
+         prefill_len = k.shape[2]
+         self.k[:, :, :prefill_len, :] = k
+         self.v[:, :, :prefill_len, :] = v
+         self.current_idx = prefill_len - 1
+
+
+ @dataclass
+ class DecoderInferenceState:
+     """Parameters specifically for decoder inference."""
+
+     device: torch.device
+     dtype: torch.dtype
+     enc_out: torch.Tensor
+     enc_positions: torch.Tensor
+     dec_positions: torch.Tensor
+     dec_cross_attn_mask: torch.Tensor
+     self_attn_cache: list[KVCache]
+     cross_attn_cache: list[KVCache]
+
+     @classmethod
+     def new(
+         cls,
+         config: DiaConfig,
+         enc_state: EncoderInferenceState,
+         enc_out: torch.Tensor,
+         dec_cross_attn_cache: list[KVCache],
+         compute_dtype: torch.dtype,
+     ) -> "DecoderInferenceState":
+         """Creates a DecoderInferenceState from the config, encoder state, and encoder output."""
+         device = enc_out.device
+         max_audio_len = config.data.audio_length
+
+         dec_positions = torch.full((2, 1), fill_value=0, dtype=torch.long, device=device)
+         tgt_padding_mask = torch.ones((2, 1), dtype=torch.bool, device=device)
+         dec_cross_attn_mask = create_attn_mask(tgt_padding_mask, enc_state.padding_mask, device, is_causal=False)
+
+         self_attn_cache = [
+             KVCache(
+                 config.model.decoder.kv_heads,
+                 max_audio_len,
+                 config.model.decoder.gqa_head_dim,
+                 compute_dtype,
+                 device,
+             )
+             for _ in range(config.model.decoder.n_layer)
+         ]
+
+         return cls(
+             device=device,
+             dtype=compute_dtype,
+             enc_out=enc_out,
+             enc_positions=enc_state.positions,
+             dec_positions=dec_positions,
+             dec_cross_attn_mask=dec_cross_attn_mask,
+             self_attn_cache=self_attn_cache,
+             cross_attn_cache=dec_cross_attn_cache,
+         )
+
+     def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
+         if step_to is None:
+             step_to = step_from + 1
+         self.dec_positions = (
+             torch.arange(step_from, step_to, dtype=torch.float32, device=self.device).unsqueeze(0).expand(2, -1)
+         )
+
+
+ @dataclass
+ class DecoderOutput:
+     generated_tokens: torch.Tensor
+     prefill_step: int
+
+     @classmethod
+     def new(cls, config: DiaConfig, device: torch.device) -> "DecoderOutput":
+         max_audio_len = config.data.audio_length
+         return cls(
+             generated_tokens=torch.full(
+                 (max_audio_len, config.data.channels),
+                 fill_value=-1,
+                 dtype=torch.int,
+                 device=device,
+             ),
+             prefill_step=0,
+         )
+
+     def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
+         if step_to is None:
+             step_to = step_from + 1
+         return self.generated_tokens[step_from:step_to, :]
+
+     def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
+         if apply_mask:
+             mask = self.generated_tokens[step : step + 1, :] == -1
+             self.generated_tokens[step : step + 1, :] = torch.where(
+                 mask, dec_out, self.generated_tokens[step : step + 1, :]
+             )
+         else:
+             self.generated_tokens[step : step + 1, :] = dec_out
+
+     def prefill(self, dec_out: torch.Tensor, prefill_step: int):
+         length = dec_out.shape[0]
+         self.generated_tokens[0:length, :] = dec_out
+         self.prefill_step = prefill_step
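
A toy sketch of `create_attn_mask` above, showing the [B, 1, Tq, Tk] output that is broadcast over attention heads (the padding masks are illustrative):

    import torch

    from dia.state import create_attn_mask

    device = torch.device("cpu")
    q_pad = torch.tensor([[True, True, False]])   # True = real token, False = padding
    k_pad = torch.tensor([[True, False, False]])

    mask = create_attn_mask(q_pad, k_pad, device, is_causal=False)
    assert mask.shape == (1, 1, 3, 3)
    # Real queries attend only to real keys; padded queries attend only to padded keys.
    print(mask[0, 0])
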