buttercrab committed on
Commit 4aa0f34 · 1 Parent(s): 9577cb2

update to faster inference

Files changed (6)
  1. app.py +17 -31
  2. dia/audio.py +27 -104
  3. dia/config.py +17 -26
  4. dia/layers.py +106 -337
  5. dia/model.py +314 -257
  6. dia/state.py +234 -0
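
Taken together, the six files move inference to precomputed encoder/decoder state objects (the new dia/state.py) and a bfloat16 compute dtype. As orientation, a minimal sketch of how the updated API is called after this commit; the keyword names come from the app.py hunks below, while the prompt text, sampling values, and output handling are illustrative assumptions rather than part of the diff:

from dia.model import Dia

# Compute dtype is now chosen at load time (see the app.py hunk below).
model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")

# `audio_prompt` replaces the old `audio_prompt_path`, and `use_cfg_filter` is gone;
# the sampling values here are placeholders, not defaults introduced by this commit.
output = model.generate(
    "[S1] Example dialogue text.",
    cfg_scale=3.0,
    temperature=1.3,
    top_p=0.95,
    cfg_filter_top_k=35,
    use_torch_compile=False,
    audio_prompt=None,
)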
app.py CHANGED
@@ -1,9 +1,7 @@
-import argparse
 import tempfile
 import time
 from pathlib import Path
 from typing import Optional, Tuple
-import spaces
 
 import gradio as gr
 import numpy as np
@@ -12,40 +10,17 @@ import torch
 
 from dia.model import Dia
 
-# --- Global Setup ---
-parser = argparse.ArgumentParser(description="Gradio interface for Nari TTS")
-parser.add_argument(
-    "--device", type=str, default=None, help="Force device (e.g., 'cuda', 'mps', 'cpu')"
-)
-parser.add_argument("--share", action="store_true", help="Enable Gradio sharing")
-
-args = parser.parse_args()
-
-
-# Determine device
-if args.device:
-    device = torch.device(args.device)
-elif torch.cuda.is_available():
-    device = torch.device("cuda")
-# Simplified MPS check for broader compatibility
-elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
-    # Basic check is usually sufficient, detailed check can be problematic
-    device = torch.device("mps")
-else:
-    device = torch.device("cpu")
-
-print(f"Using device: {device}")
 
 # Load Nari model and config
 print("Loading Nari model...")
 try:
     # Use the function from inference.py
-    model = Dia.from_pretrained("nari-labs/Dia-1.6B")
+    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
 except Exception as e:
     print(f"Error loading Nari model: {e}")
     raise
 
-@spaces.GPU
+
 def run_inference(
     text_input: str,
     audio_prompt_input: Optional[Tuple[int, np.ndarray]],
@@ -60,7 +35,7 @@ def run_inference(
     Runs Nari inference using the globally loaded model and provided inputs.
     Uses temporary files for text and audio prompt compatibility with inference.generate.
     """
-    # global model, device  # Access global model, config, device
+    global model, device  # Access global model, config, device
 
    if not text_input or text_input.isspace():
        raise gr.Error("Text input cannot be empty.")
@@ -146,10 +121,9 @@ def run_inference(
             cfg_scale=cfg_scale,
             temperature=temperature,
             top_p=top_p,
-            use_cfg_filter=True,
             cfg_filter_top_k=cfg_filter_top_k,  # Pass the value here
             use_torch_compile=False,  # Keep False for Gradio stability
-            audio_prompt_path=prompt_path_for_generate,
+            audio_prompt=prompt_path_for_generate,
         )
 
         end_time = time.time()
@@ -192,6 +166,16 @@ def run_inference(
                 f"Audio conversion successful. Final shape: {output_audio[1].shape}, Sample Rate: {output_sr}"
             )
 
+            # Explicitly convert to int16 to prevent Gradio warning
+            if (
+                output_audio[1].dtype == np.float32
+                or output_audio[1].dtype == np.float64
+            ):
+                audio_for_gradio = np.clip(output_audio[1], -1.0, 1.0)
+                audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
+                output_audio = (output_sr, audio_for_gradio)
+                print("Converted audio to int16 for Gradio output.")
+
         else:
             print("\nGeneration finished, but no valid tokens were produced.")
             # Return default silence
@@ -383,8 +367,10 @@ with gr.Blocks(css=css) as demo:
     else:
         gr.Markdown("_(No examples configured or example prompt file missing)_")
 
-
 # --- Launch the App ---
 if __name__ == "__main__":
     print("Launching Gradio interface...")
+
+    # set `GRADIO_SERVER_NAME`, `GRADIO_SERVER_PORT` env vars to override default values
+    # use `GRADIO_SERVER_NAME=0.0.0.0` for Docker
     demo.launch()
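
The int16 conversion added to run_inference can be sanity-checked in isolation; this is the same clip-and-scale logic applied to a toy float array (the sample values are made up):

import numpy as np

audio_f32 = np.array([-1.5, -0.5, 0.0, 0.5, 1.5], dtype=np.float32)
# Clip to [-1, 1], then scale to the int16 range, exactly as in the hunk above.
audio_i16 = (np.clip(audio_f32, -1.0, 1.0) * 32767).astype(np.int16)
print(audio_i16)  # [-32767 -16383      0  16383  32767]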
dia/audio.py CHANGED
@@ -2,10 +2,10 @@ import typing as tp
 
 import torch
 
-from .config import DataConfig
 
 
-def build_delay_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+def build_delay_indices(
+    B: int, T: int, C: int, delay_pattern: tp.List[int]
+) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """
     Precompute (t_idx_BxTxC, indices_BTCx3) so that out[t, c] = in[t - delay[c], c].
     Negative t_idx => BOS; t_idx >= T => PAD.
@@ -69,7 +69,9 @@ def apply_audio_delay(
 
     # Equivalent of tf.gather_nd using advanced indexing
     # Ensure indices are long type if not already (build_delay_indices should handle this)
-    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
+    gathered_flat = audio_BxTxC[
+        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
+    ]
     gathered_BxTxC = gathered_flat.view(audio_BxTxC.shape)
 
     # Create masks on the correct device
@@ -82,65 +84,16 @@ def apply_audio_delay(
 
     # If mask_bos, BOS; else if mask_pad, PAD; else original gather
     # All tensors should now be on the same device
-    result_BxTxC = torch.where(mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC))
+    result_BxTxC = torch.where(
+        mask_bos, bos_tensor, torch.where(mask_pad, pad_tensor, gathered_BxTxC)
+    )
 
     return result_BxTxC
 
 
-@torch.no_grad()
-@torch.inference_mode()
-def audio_to_codebook(
-    model,
-    input_values,
-    data_config: DataConfig,
-    padding_mask=None,
-    sample_rate=44100,
-):
-    """
-    Encodes the input audio waveform into discrete codes.
-
-    Args:
-        model: The model to use for encoding.
-        input_values (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
-            Float values of the input audio waveform.
-        padding_mask (`torch.Tensor` of shape `(batch_size, channels, sequence_length)`):
-            Padding mask used to pad the `input_values`.
-        sample_rate (`int`, *optional*):
-            Signal sampling_rate
-
-    Returns:
-        A list of frames containing the discrete encoded codes for the input audio waveform, along with rescaling
-        factors for each chunk when `normalize` is True. Each frames is a tuple `(codebook, scale)`, with
-        `codebook` of shape `[batch_size, num_codebooks, frames]`.
-        Scale is not used here.
-
-    """
-    audio_data = model.preprocess(input_values, sample_rate)
-
-    if padding_mask is None:
-        padding_mask = torch.ones_like(input_values).bool()
-
-    _, encoded_frame, _, _, _ = model.encode(audio_data, n_quantizers=None)  # 1, C, T
-    seq_length = encoded_frame.shape[2]
-
-    t_idx_BxTxC, indices_BTCx3 = build_delay_indices(
-        B=1,
-        T=seq_length,
-        C=data_config.channels,
-        delay_pattern=data_config.delay_pattern,
-    )
-
-    encoded_frame = apply_audio_delay(
-        audio_BxTxC=encoded_frame.transpose(1, 2),  # 1, T, C
-        pad_value=data_config.audio_pad_value,
-        bos_value=data_config.audio_bos_value,
-        precomp=(t_idx_BxTxC, indices_BTCx3),
-    )
-
-    return encoded_frame
-
-
-def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
+def build_revert_indices(
+    B: int, T: int, C: int, delay_pattern: tp.List[int]
+) -> tp.Tuple[torch.Tensor, torch.Tensor]:
     """
     Precompute indices for the revert operation using PyTorch.
 
@@ -162,8 +115,12 @@ def build_revert_indices(B: int, T: int, C: int, delay_pattern: tp.List[int]) ->
         t_idx_BT1 + delay_arr.view(1, 1, C),
         torch.tensor(T - 1, device=device),
     )
-    b_idx_BxTxC = torch.broadcast_to(torch.arange(B, device=device).view(B, 1, 1), [B, T, C])
-    c_idx_BxTxC = torch.broadcast_to(torch.arange(C, device=device).view(1, 1, C), [B, T, C])
+    b_idx_BxTxC = torch.broadcast_to(
+        torch.arange(B, device=device).view(B, 1, 1), [B, T, C]
+    )
+    c_idx_BxTxC = torch.broadcast_to(
+        torch.arange(C, device=device).view(1, 1, C), [B, T, C]
+    )
 
     indices_BTCx3 = torch.stack(
         [
@@ -205,15 +162,21 @@ def revert_audio_delay(
     indices_BTCx3 = indices_BTCx3.to(device)
 
     # Using PyTorch advanced indexing (equivalent to tf.gather_nd or np equivalent)
-    gathered_flat = audio_BxTxC[indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]]
-    gathered_BxTxC = gathered_flat.view(audio_BxTxC.size())  # Use .size() for robust reshaping
+    gathered_flat = audio_BxTxC[
+        indices_BTCx3[:, 0], indices_BTCx3[:, 1], indices_BTCx3[:, 2]
+    ]
+    gathered_BxTxC = gathered_flat.view(
+        audio_BxTxC.size()
+    )  # Use .size() for robust reshaping
 
     # Create pad_tensor on the correct device
     pad_tensor = torch.tensor(pad_value, dtype=audio_BxTxC.dtype, device=device)
     # Create T tensor on the correct device for comparison
     T_tensor = torch.tensor(T, device=device)
 
-    result_BxTxC = torch.where(t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC)  # Changed np.where to torch.where
+    result_BxTxC = torch.where(
+        t_idx_BxTxC >= T_tensor, pad_tensor, gathered_BxTxC
+    )  # Changed np.where to torch.where
 
     return result_BxTxC
 
@@ -238,43 +201,3 @@ def decode(
     except Exception as e:
         print(f"Error in decode method: {str(e)}")
         raise
-
-
-def codebook_to_audio(generated_codes: torch.Tensor, model, delay_pattern, B=1, T=2600, C=9):
-    """Process a single codebook file to generate audio"""
-    # Remove BOS token
-    generated_codes = generated_codes[:, 1:]
-
-    if generated_codes.shape[1] > T:
-        generated_codes = generated_codes[:, :T]
-
-    seq_length = generated_codes.shape[1]
-
-    # Build revert indices
-    t_idx_BxTxC, indices_BTCx3 = build_revert_indices(B=B, T=seq_length, C=C, delay_pattern=delay_pattern)
-
-    # Transpose and add batch dimension
-    audio_BxTxC = generated_codes.transpose(1, 0).unsqueeze(0)
-    reverted_codebook = revert_audio_delay(
-        audio_BxTxC=audio_BxTxC,
-        pad_value=0,
-        precomp=(t_idx_BxTxC, indices_BTCx3),
-        T=seq_length,
-    )
-    reverted_codebook = reverted_codebook[:, :-30, :]
-
-    codebook = reverted_codebook.transpose(1, 2)
-
-    min_valid_index = 0
-    max_valid_index = 1023
-    invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
-
-    num_invalid = torch.sum(invalid_mask).item()
-    if num_invalid > 0:
-        print(f"Warning: Clamping {num_invalid} indices outside range [{min_valid_index}, {max_valid_index}] to 0.")
-
-    # Set invalid values to 0 (modify the tensor in-place)
-    codebook[invalid_mask] = 0
-    audio_array = decode(model, codebook)
-
-    return audio_array
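
A small sketch of how the retained delay helpers fit together, assuming the dia.audio module layout shown above; the toy code values and the short 3-channel delay pattern are illustrative, not the model's real 9-channel pattern:

import torch

from dia.audio import apply_audio_delay, build_delay_indices

B, T, C = 1, 6, 3
codes = torch.arange(B * T * C).view(B, T, C)  # toy codes, shape (B, T, C)
delay_pattern = [0, 1, 2]  # channel c is shifted right by delay_pattern[c]

precomp = build_delay_indices(B=B, T=T, C=C, delay_pattern=delay_pattern)
delayed = apply_audio_delay(
    audio_BxTxC=codes,
    pad_value=1025,  # audio_pad_value in DataConfig
    bos_value=1026,  # audio_bos_value in DataConfig
    precomp=precomp,
)
# delayed[:, t, c] == codes[:, t - delay_pattern[c], c]; BOS fills t - delay < 0.
print(delayed[0])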
dia/config.py CHANGED
@@ -33,14 +33,20 @@ class DataConfig(BaseModel, frozen=True):
         delay_pattern: List of delay values for each audio channel.
     """
 
-    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
-    audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(gt=0, multiple_of=128)
+    text_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
+        Field(gt=0, multiple_of=128)
+    )
+    audio_length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = (
+        Field(gt=0, multiple_of=128)
+    )
     channels: int = Field(default=9, gt=0, multiple_of=1)
     text_pad_value: int = Field(default=0)
     audio_eos_value: int = Field(default=1024)
     audio_pad_value: int = Field(default=1025)
     audio_bos_value: int = Field(default=1026)
-    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15])
+    delay_pattern: list[Annotated[int, Field(ge=0)]] = Field(
+        default_factory=lambda: [0, 8, 9, 10, 11, 12, 13, 14, 15]
+    )
 
     def __hash__(self) -> int:
         """Generate a hash based on all fields of the config."""
@@ -67,8 +73,6 @@ class EncoderConfig(BaseModel, frozen=True):
         n_hidden: Hidden dimension size in the MLP layers.
         n_head: Number of attention heads.
         head_dim: Dimension per attention head.
-        mlp_activations: List of activation functions for the MLP layers.
-        use_pre_norm: Whether to use pre-normalization (LayerNorm before attention/MLP).
     """
 
     n_layer: int = Field(gt=0)
@@ -76,8 +80,6 @@ class EncoderConfig(BaseModel, frozen=True):
     n_hidden: int = Field(gt=0)
     n_head: int = Field(gt=0)
     head_dim: int = Field(gt=0)
-    mlp_activations: list[str] = Field(default=["silu", "linear"])
-    use_pre_norm: bool = Field(default=False)
 
 
 class DecoderConfig(BaseModel, frozen=True):
@@ -92,8 +94,6 @@ class DecoderConfig(BaseModel, frozen=True):
         gqa_head_dim: Dimension per query head for grouped-query self-attention.
         cross_query_heads: Number of query heads for cross-attention.
         cross_head_dim: Dimension per cross-attention head.
-        mlp_activations: List of activation functions for the MLP layers.
-        use_pre_norm: Whether to use pre-normalization.
     """
 
     n_layer: int = Field(gt=0)
@@ -104,8 +104,6 @@ class DecoderConfig(BaseModel, frozen=True):
     gqa_head_dim: int = Field(gt=0)
     cross_query_heads: int = Field(gt=0)
     cross_head_dim: int = Field(gt=0)
-    mlp_activations: list[str] = Field(default=["silu", "linear"])
-    use_pre_norm: bool = Field(default=False)
 
 
 class ModelConfig(BaseModel, frozen=True):
@@ -130,24 +128,16 @@ class ModelConfig(BaseModel, frozen=True):
     dropout: float = Field(default=0.0, ge=0.0, lt=1.0)
     normalization_layer_epsilon: float = Field(default=1.0e-5, ge=0.0)
     weight_dtype: str = Field(default="float32", description="Weight precision")
-    rope_min_timescale: int = Field(default=1, description="Timescale For global Attention")
-    rope_max_timescale: int = Field(default=10_000, description="Timescale For global Attention")
+    rope_min_timescale: int = Field(
+        default=1, description="Timescale For global Attention"
+    )
+    rope_max_timescale: int = Field(
+        default=10_000, description="Timescale For global Attention"
+    )
 
 
 class TrainingConfig(BaseModel, frozen=True):
-    """Training process configuration and hyperparameters.
-
-    Note: This configuration currently only includes precision settings.
-    Other training parameters (like batch size, learning rate, optimizer settings)
-    are assumed to be handled externally.
-
-    Attributes:
-        dtype: Data type for activations during training (e.g., "bfloat16", "float32").
-        logits_dot_in_fp32: Whether to compute the final logits dot product in fp32 for stability.
-    """
-
-    dtype: str = Field(default="bfloat16", description="Activation precision")
-    logits_dot_in_fp32: bool = Field(default=False)
+    pass
 
 
 class DiaConfig(BaseModel, frozen=True):
@@ -164,6 +154,7 @@ class DiaConfig(BaseModel, frozen=True):
 
     version: str = Field(default="1.0")
     model: ModelConfig
+    # TODO: remove training. this is just for backwards-compatability
     training: TrainingConfig
     data: DataConfig
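
The `BeforeValidator` kept on `text_length` and `audio_length` rounds a requested length up to the next multiple of 128 before the `multiple_of=128` constraint is checked. A standalone sketch of the same validator on a hypothetical `LengthDemo` model (pydantic v2):

from typing import Annotated

from pydantic import BaseModel, BeforeValidator, Field

class LengthDemo(BaseModel, frozen=True):
    # Same rounding expression as DataConfig.text_length / audio_length above.
    length: Annotated[int, BeforeValidator(lambda x: (x + 127) // 128 * 128)] = Field(
        gt=0, multiple_of=128
    )

print(LengthDemo(length=1000).length)  # 1024
print(LengthDemo(length=1024).length)  # 1024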
 
dia/layers.py CHANGED
@@ -1,5 +1,3 @@
1
- from typing import Any
2
-
3
  import torch
4
  import torch.nn as nn
5
  import torch.nn.functional as F
@@ -7,26 +5,13 @@ from torch import Tensor
7
  from torch.nn import RMSNorm
8
 
9
  from .config import DiaConfig
 
10
 
11
 
12
  def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
13
  return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
14
 
15
 
16
- def _str_to_dtype(dtype_str: str) -> torch.dtype | None:
17
- # Allow None for default behavior
18
- if dtype_str is None or dtype_str.lower() == "none":
19
- return None
20
- if dtype_str == "float32":
21
- return torch.float32
22
- elif dtype_str == "float16":
23
- return torch.float16
24
- elif dtype_str == "bfloat16":
25
- return torch.bfloat16
26
- else:
27
- raise ValueError(f"Unsupported dtype string: {dtype_str}")
28
-
29
-
30
  class DenseGeneral(nn.Module):
31
  """
32
  PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
@@ -50,7 +35,6 @@ class DenseGeneral(nn.Module):
50
  in_shapes: tuple[int, ...],
51
  out_features: tuple[int, ...],
52
  axis: tuple[int, ...] = (-1,),
53
- dtype: torch.dtype | None = None,
54
  weight_dtype: torch.dtype | None = None,
55
  device: torch.device | None = None,
56
  ):
@@ -58,7 +42,6 @@ class DenseGeneral(nn.Module):
58
  self.in_shapes = in_shapes
59
  self.out_features = out_features
60
  self.axis = axis
61
- self.dtype = dtype
62
  self.kernel_shape = self.in_shapes + self.out_features
63
 
64
  factory_kwargs = {"device": device, "dtype": weight_dtype}
@@ -70,95 +53,44 @@ class DenseGeneral(nn.Module):
70
  kernel_contract_axes = tuple(range(len(norm_axis)))
71
 
72
  output = torch.tensordot(
73
- inputs.float(),
74
- self.weight.float(),
75
  dims=(norm_axis, kernel_contract_axes),
76
  ).to(inputs.dtype)
77
  return output
78
 
79
 
80
- def get_activation_fn(activation_string: str) -> nn.Module: # Return Module instance
81
- """Maps activation string to PyTorch activation function module."""
82
- if activation_string == "gelu":
83
- return nn.GELU()
84
- elif activation_string == "relu":
85
- return nn.ReLU()
86
- elif activation_string == "silu" or activation_string == "swish":
87
- return nn.SiLU()
88
- elif activation_string == "linear":
89
- return nn.Identity()
90
- else:
91
- raise ValueError(f"Unsupported activation function: {activation_string}")
92
-
93
-
94
  class MlpBlock(nn.Module):
95
  """MLP block using DenseGeneral."""
96
 
97
  def __init__(
98
- self,
99
- config: DiaConfig,
100
- embed_dim: int,
101
- intermediate_dim: int,
102
- dropout_rate: float,
103
- activations: list[str] = ["silu", "linear"],
104
- use_pre_norm: bool = False,
105
  ):
106
  super().__init__()
107
- self.use_pre_norm = use_pre_norm
108
- num_activations = len(activations)
109
- compute_dtype = _str_to_dtype(config.training.dtype)
110
- weight_dtype = _str_to_dtype(config.model.weight_dtype)
111
  self.dtype = compute_dtype
112
- # Assume default device for now, could be passed in config
113
-
114
- if use_pre_norm:
115
- self.pre_norm = RMSNorm(
116
- embed_dim,
117
- eps=config.model.normalization_layer_epsilon,
118
- dtype=torch.float32,
119
- )
120
 
121
  self.wi_fused = DenseGeneral(
122
  in_shapes=(embed_dim,),
123
- out_features=(
124
- num_activations,
125
- intermediate_dim,
126
- ),
127
  axis=(-1,),
128
- dtype=compute_dtype,
129
- weight_dtype=weight_dtype,
130
  )
131
 
132
- self.activation_fn_0 = get_activation_fn(activations[0]) # silu
133
- self.activation_fn_1 = get_activation_fn(activations[1]) # linear
134
-
135
- self.dropout = nn.Dropout(dropout_rate)
136
-
137
- # Output layer using DenseGeneral
138
  self.wo = DenseGeneral(
139
  in_shapes=(intermediate_dim,),
140
  out_features=(embed_dim,),
141
  axis=(-1,),
142
- dtype=compute_dtype,
143
- weight_dtype=weight_dtype,
144
  )
145
 
146
- def forward(self, x: torch.Tensor, deterministic: bool) -> torch.Tensor:
147
  """Forward pass."""
148
- if self.use_pre_norm and hasattr(self, "pre_norm"):
149
- x = self.pre_norm(x)
150
-
151
  fused_x = self.wi_fused(x)
152
 
153
- gate_input = fused_x[..., 0, :]
154
- up_input = fused_x[..., 1, :]
155
-
156
- gate = self.activation_fn_0(gate_input)
157
- up = self.activation_fn_1(up_input)
158
- hidden = torch.mul(gate, up).to(self.dtype)
159
 
160
- if not deterministic:
161
- hidden = self.dropout(hidden)
162
 
163
  output = self.wo(hidden)
164
  return output
@@ -207,37 +139,6 @@ class RotaryEmbedding(nn.Module):
207
  return torch.cat((first_part, second_part), dim=-1)
208
 
209
 
210
- class KVCache:
211
- def __init__(self, num_heads, max_len, head_dim, device, k=None, v=None):
212
- self.k = torch.zeros((2, num_heads, max_len, head_dim), device=device) if k is None else k
213
- self.v = torch.zeros((2, num_heads, max_len, head_dim), device=device) if v is None else v
214
- self.current_idx = 0
215
- self.max_len = max_len
216
-
217
- def get_kv_for_attention(self, current_k, current_v):
218
- if self.current_idx == 0:
219
- return current_k, current_v
220
- else:
221
- past_k = self.k[:, :, : self.current_idx, :]
222
- past_v = self.v[:, :, : self.current_idx, :]
223
- attn_k = torch.cat((past_k, current_k), dim=2)
224
- attn_v = torch.cat((past_v, current_v), dim=2)
225
- return attn_k, attn_v
226
-
227
- def update_cache(self, k, v):
228
- assert self.current_idx < self.max_len
229
- self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
230
- self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
231
- self.current_idx += 1
232
-
233
- def prefill_kv(self, k, v):
234
- prefill_len = k.shape[2]
235
- assert prefill_len <= self.max_len
236
- self.k[:, :, :prefill_len, :] = k
237
- self.v[:, :, :prefill_len, :] = v
238
- self.current_idx = prefill_len
239
-
240
-
241
  class Attention(nn.Module):
242
  """Attention using DenseGeneral."""
243
 
@@ -249,7 +150,7 @@ class Attention(nn.Module):
249
  num_query_heads: int,
250
  num_kv_heads: int,
251
  head_dim: int,
252
- dropout_rate: float,
253
  is_cross_attn: bool = False,
254
  out_embed_dim: int | None = None,
255
  ):
@@ -258,13 +159,12 @@ class Attention(nn.Module):
258
  self.num_kv_heads = num_kv_heads
259
  self.head_dim = head_dim
260
  self.is_cross_attn = is_cross_attn
261
- self.dropout_rate = dropout_rate
262
- compute_dtype = _str_to_dtype(config.training.dtype)
263
- weight_dtype = _str_to_dtype(config.model.weight_dtype)
264
  self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
265
  self.projected_query_dim = num_query_heads * head_dim
266
  if num_query_heads % num_kv_heads != 0:
267
- raise ValueError(f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})")
 
 
268
  self.num_gqa_groups = num_query_heads // num_kv_heads
269
 
270
  # --- Projection Layers using DenseGeneral ---
@@ -272,29 +172,25 @@ class Attention(nn.Module):
272
  in_shapes=(q_embed_dim,),
273
  out_features=(num_query_heads, head_dim),
274
  axis=(-1,),
275
- dtype=compute_dtype,
276
- weight_dtype=weight_dtype,
277
  )
278
  self.k_proj = DenseGeneral(
279
  in_shapes=(kv_embed_dim,),
280
  out_features=(num_kv_heads, head_dim),
281
  axis=(-1,),
282
- dtype=compute_dtype,
283
- weight_dtype=weight_dtype,
284
  )
285
  self.v_proj = DenseGeneral(
286
  in_shapes=(kv_embed_dim,),
287
  out_features=(num_kv_heads, head_dim),
288
  axis=(-1,),
289
- dtype=compute_dtype,
290
- weight_dtype=weight_dtype,
291
  )
292
  self.o_proj = DenseGeneral(
293
  in_shapes=(num_query_heads, head_dim),
294
  out_features=(self.output_dim,),
295
  axis=(-2, -1),
296
- dtype=compute_dtype,
297
- weight_dtype=weight_dtype,
298
  )
299
 
300
  # --- Rotary Embedding ---
@@ -311,10 +207,11 @@ class Attention(nn.Module):
311
  Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
312
  q_positions: torch.Tensor, # (B, T)
313
  kv_positions: torch.Tensor | None = None, # (B, S)
314
- deterministic: bool = True,
315
- attn_mask: torch.Tensor | None = None, # None in Decoder Self Attention, Valid mask in Others
316
  cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
317
- prefill: bool = False, # True only when prefilling KV Cache
 
318
  ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
319
  """
320
  Performs attention calculation with optional KV caching.
@@ -324,7 +221,6 @@ class Attention(nn.Module):
324
  Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
325
  q_positions: Positions for queries (B, T).
326
  kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
327
- deterministic: If True, disable dropout.
328
  attn_mask: Attention mask.
329
  cache: KVCache.
330
  prefill: If True, use prefill mode.
@@ -342,72 +238,51 @@ class Attention(nn.Module):
342
  Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
343
  Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
344
 
345
- # Input values into attention calculation
346
  attn_k: torch.Tensor | None = None
347
  attn_v: torch.Tensor | None = None
348
- new_kv_cache: tuple[torch.Tensor, torch.Tensor] | None = None
349
 
350
- # Decoder Cross Attention
351
  if self.is_cross_attn:
352
- # Directly use cache (no need to check index)
353
  attn_k, attn_v = cache.k, cache.v
354
- if attn_k.shape[1] != self.num_query_heads or attn_v.shape[1] != self.num_query_heads:
355
- raise ValueError(
356
- f"Cross-attention cache head dimension ({attn_k.shape[1]}) "
357
- f"does not match num_query_heads ({self.num_query_heads}). "
358
- "Cache should be pre-repeated for GQA."
359
- )
360
- # Self Attention
361
  else:
362
  Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
363
  Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
364
- Xk_BxSxKxH = self.rotary_emb(Xk_BxSxKxH, position=kv_positions) # (B, S, K, H)
 
 
365
 
366
  Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
367
  Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
368
- # S=1 for Decode Step
369
-
370
- if self.num_gqa_groups > 1:
371
- Xk_BxNxSxH = Xk_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
372
- Xv_BxNxSxH = Xv_BxKxSxH.repeat_interleave(self.num_gqa_groups, dim=1)
373
- else:
374
- Xk_BxNxSxH = Xk_BxKxSxH
375
- Xv_BxNxSxH = Xv_BxKxSxH
376
 
377
- # Encoder Self Attention
378
  if cache is None:
379
- attn_k = Xk_BxNxSxH
380
- attn_v = Xv_BxNxSxH
381
- # Decoder Self Attention
382
  else:
383
- # In prefill mode, we fill in cache until prefill length
384
  if prefill:
385
- attn_k, attn_v = Xk_BxNxSxH, Xv_BxNxSxH
386
- cache.prefill_kv(attn_k, attn_v)
387
- # In decode step, we add current K/V to cache step by step
388
  else:
389
- new_kv_cache = Xk_BxNxSxH, Xv_BxNxSxH
390
- attn_k, attn_v = cache.get_kv_for_attention(Xk_BxNxSxH, Xv_BxNxSxH)
391
 
392
  attn_output = F.scaled_dot_product_attention(
393
  Xq_BxNxTxH,
394
  attn_k,
395
  attn_v,
396
  attn_mask=attn_mask,
397
- dropout_p=self.dropout_rate if not deterministic else 0.0,
398
  scale=1.0,
 
 
399
  )
400
 
401
  attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
402
  output = self.o_proj(attn_output)
403
 
404
- return output.to(original_dtype), new_kv_cache
405
 
406
 
407
  class EncoderLayer(nn.Module):
408
  """Transformer Encoder Layer using DenseGeneral."""
409
 
410
- def __init__(self, config: DiaConfig):
411
  super().__init__()
412
  self.config = config
413
  model_config = config.model
@@ -420,13 +295,13 @@ class EncoderLayer(nn.Module):
420
  dtype=torch.float32,
421
  )
422
  self.self_attention = Attention(
423
- config=config,
424
  q_embed_dim=embed_dim,
425
  kv_embed_dim=embed_dim,
426
  num_query_heads=enc_config.n_head,
427
  num_kv_heads=enc_config.n_head,
428
  head_dim=enc_config.head_dim,
429
- dropout_rate=model_config.dropout,
430
  is_cross_attn=False,
431
  out_embed_dim=embed_dim,
432
  )
@@ -436,62 +311,52 @@ class EncoderLayer(nn.Module):
436
  dtype=torch.float32,
437
  )
438
  self.mlp = MlpBlock(
439
- config=config,
440
  embed_dim=embed_dim,
441
  intermediate_dim=enc_config.n_hidden,
442
- activations=enc_config.mlp_activations,
443
- dropout_rate=model_config.dropout,
444
- use_pre_norm=enc_config.use_pre_norm,
445
  )
446
- self.dropout = nn.Dropout(model_config.dropout)
447
 
448
  def forward(
449
  self,
450
  x: torch.Tensor,
451
- src_positions: torch.Tensor | None = None,
452
- deterministic: bool = True,
453
- attn_mask: torch.Tensor | None = None,
454
  ) -> torch.Tensor:
455
  residual = x
456
  x_norm = self.pre_sa_norm(x)
457
-
458
- sa_out, _ = self.self_attention(
459
  Xq=x_norm,
460
  Xkv=x_norm,
461
- q_positions=src_positions,
462
- kv_positions=src_positions,
463
- deterministic=deterministic,
464
- attn_mask=attn_mask,
465
  )
466
  x = residual + sa_out
467
 
468
  residual = x
469
  x_norm = self.post_sa_norm(x)
470
- mlp_out = self.mlp(x_norm, deterministic=deterministic)
471
  x = residual + mlp_out
472
 
473
- if not deterministic:
474
- x = self.dropout(x)
475
  return x
476
 
477
 
478
  class Encoder(nn.Module):
479
  """Transformer Encoder Stack using DenseGeneral."""
480
 
481
- def __init__(self, config: DiaConfig):
482
  super().__init__()
483
  self.config = config
484
  model_config = config.model
485
  enc_config = config.model.encoder
486
- compute_dtype = _str_to_dtype(config.training.dtype)
487
 
488
  self.embedding = nn.Embedding(
489
  model_config.src_vocab_size,
490
  enc_config.n_embd,
491
  dtype=compute_dtype,
492
  )
493
- self.dropout = nn.Dropout(model_config.dropout)
494
- self.layers = nn.ModuleList([EncoderLayer(config=config) for _ in range(enc_config.n_layer)])
 
495
  self.norm = RMSNorm(
496
  enc_config.n_embd,
497
  eps=model_config.normalization_layer_epsilon,
@@ -501,32 +366,21 @@ class Encoder(nn.Module):
501
  def forward(
502
  self,
503
  x_ids: torch.Tensor,
504
- src_positions: torch.Tensor | None = None,
505
- deterministic: bool = True,
506
- attn_mask: torch.Tensor | None = None,
507
  ) -> torch.Tensor:
508
  x = self.embedding(x_ids)
509
 
510
- if not deterministic:
511
- x = self.dropout(x)
512
-
513
  for layer in self.layers:
514
- x = layer(
515
- x,
516
- src_positions=src_positions,
517
- deterministic=deterministic,
518
- attn_mask=attn_mask,
519
- )
520
  x = self.norm(x)
521
- if not deterministic:
522
- x = self.dropout(x)
523
  return x
524
 
525
 
526
  class DecoderLayer(nn.Module):
527
  """Transformer Decoder Layer using DenseGeneral."""
528
 
529
- def __init__(self, config: DiaConfig):
530
  super().__init__()
531
  self.config = config
532
  model_config = config.model
@@ -554,13 +408,13 @@ class DecoderLayer(nn.Module):
554
 
555
  # Self-Attention (GQA) with Causal Masking
556
  self.self_attention = Attention(
557
- config=config,
558
  q_embed_dim=dec_embed_dim,
559
  kv_embed_dim=dec_embed_dim,
560
  num_query_heads=dec_config.gqa_query_heads,
561
  num_kv_heads=dec_config.kv_heads,
562
  head_dim=dec_config.gqa_head_dim,
563
- dropout_rate=model_config.dropout,
564
  is_cross_attn=False,
565
  out_embed_dim=dec_embed_dim,
566
  )
@@ -572,116 +426,105 @@ class DecoderLayer(nn.Module):
572
  num_query_heads=dec_config.cross_query_heads,
573
  num_kv_heads=dec_config.cross_query_heads,
574
  head_dim=dec_config.cross_head_dim,
575
- dropout_rate=model_config.dropout,
576
  is_cross_attn=True,
577
  out_embed_dim=dec_embed_dim,
578
  )
579
  # MLP
580
  self.mlp = MlpBlock(
581
- config=config,
582
  embed_dim=dec_embed_dim,
583
  intermediate_dim=dec_config.n_hidden,
584
- activations=dec_config.mlp_activations,
585
- dropout_rate=model_config.dropout,
586
- use_pre_norm=dec_config.use_pre_norm,
587
  )
588
 
589
  def forward(
590
  self,
591
  x: torch.Tensor,
592
- encoder_out: torch.Tensor,
593
- tgt_positions: torch.Tensor,
594
- src_positions: torch.Tensor | None,
595
- deterministic: bool,
596
- self_attn_mask: torch.Tensor,
597
- cross_attn_mask: torch.Tensor,
598
- self_attn_cache: KVCache,
599
- cross_attn_cache: KVCache,
600
  prefill: bool = False,
601
  ) -> torch.Tensor:
602
  residual = x
603
  x_norm = self.pre_sa_norm(x)
604
 
605
- sa_out, new_kv_cache = self.self_attention(
606
  Xq=x_norm, # (2, 1, D)
607
  Xkv=x_norm, # (2, 1, D)
608
- q_positions=tgt_positions, # (2, 1)
609
- kv_positions=tgt_positions, # (2, 1)
610
- deterministic=deterministic,
611
- attn_mask=self_attn_mask, # (2, 1, 1, S_max)
612
  cache=self_attn_cache,
613
  prefill=prefill,
 
614
  )
615
 
616
  x = residual + sa_out
617
 
618
- # 2. Cross-Attention
619
  residual = x
620
  x_norm = self.pre_ca_norm(x)
621
- ca_out, _ = self.cross_attention(
622
  Xq=x_norm,
623
- Xkv=encoder_out,
624
- q_positions=tgt_positions,
625
- kv_positions=src_positions,
626
- deterministic=deterministic,
627
- attn_mask=cross_attn_mask,
628
  cache=cross_attn_cache,
629
  )
630
  x = residual + ca_out
631
 
632
- # 3. MLP
633
  residual = x
634
  x_norm = self.pre_mlp_norm(x)
635
- mlp_out = self.mlp(x_norm, deterministic=deterministic)
636
  x = residual + mlp_out
637
 
638
- return x, new_kv_cache
639
 
640
 
641
  class Decoder(nn.Module):
642
  """Transformer Decoder Stack using DenseGeneral."""
643
 
644
- def __init__(self, config: DiaConfig):
645
  super().__init__()
646
  self.config = config
647
  model_config = config.model
648
  dec_config = config.model.decoder
649
- train_config = config.training
650
  data_config = config.data
651
- compute_dtype = _str_to_dtype(config.training.dtype)
652
- weight_dtype = _str_to_dtype(config.model.weight_dtype)
653
  self.num_channels = data_config.channels
654
  self.num_layers = dec_config.n_layer
655
 
656
  self.embeddings = nn.ModuleList(
657
  [
658
- nn.Embedding(model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype)
 
 
659
  for _ in range(self.num_channels)
660
  ]
661
  )
662
- self.dropout = nn.Dropout(model_config.dropout)
663
- self.layers = nn.ModuleList([DecoderLayer(config=config) for _ in range(self.num_layers)])
 
 
 
 
 
664
  self.norm = RMSNorm(
665
  dec_config.n_embd,
666
  eps=model_config.normalization_layer_epsilon,
667
  dtype=torch.float32,
668
  )
669
 
670
- # Final Logits Projection using DenseGeneral
671
  self.logits_dense = DenseGeneral(
672
  in_shapes=(dec_config.n_embd,),
673
  out_features=(self.num_channels, model_config.tgt_vocab_size),
674
  axis=(-1,),
675
- dtype=(torch.float32 if train_config.logits_dot_in_fp32 else compute_dtype),
676
- weight_dtype=weight_dtype,
677
  )
678
- self.logits_in_fp32 = train_config.logits_dot_in_fp32
679
 
680
- def precompute_cross_attention_kv(
681
  self,
682
- max_len: int,
683
- encoder_out: torch.Tensor, # (B, S, E)
684
- src_positions: torch.Tensor | None, # (B, S)
685
  ) -> list[KVCache]:
686
  """
687
  Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
@@ -690,35 +533,21 @@ class Decoder(nn.Module):
690
 
691
  for layer in self.layers:
692
  cross_attn_module = layer.cross_attention
693
- k_proj = cross_attn_module.k_proj(encoder_out)
694
- v_proj = cross_attn_module.v_proj(encoder_out)
695
 
696
- k_proj = cross_attn_module.rotary_emb(k_proj, position=src_positions)
697
  k = k_proj.transpose(1, 2)
698
  v = v_proj.transpose(1, 2)
699
 
700
- per_layer_kv_cache.append(
701
- KVCache(
702
- cross_attn_module.num_kv_heads,
703
- max_len,
704
- cross_attn_module.head_dim,
705
- k.device,
706
- k=k,
707
- v=v,
708
- )
709
- )
710
 
711
  return per_layer_kv_cache
712
 
713
  def decode_step(
714
  self,
715
  tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
716
- tgt_pos_Bx1: torch.Tensor, # [B, 1]
717
- encoder_out: torch.Tensor, # [B, S, E]
718
- self_attn_mask: Any, # None
719
- cross_attn_mask: torch.Tensor, # [B, 1, 1, S]
720
- self_attention_cache: list[KVCache],
721
- cross_attention_cache: list[KVCache],
722
  ) -> torch.Tensor:
723
  """
724
  Performs a single decoding step, managing KV caches layer by layer.
@@ -727,7 +556,6 @@ class Decoder(nn.Module):
727
  A tuple containing:
728
  - logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32.
729
  """
730
- assert self_attn_mask is None, "Self-attention mask should be None, kept for pattern"
731
 
732
  x = None
733
  for i in range(self.num_channels):
@@ -735,40 +563,23 @@ class Decoder(nn.Module):
735
  channel_embed = self.embeddings[i](channel_tokens)
736
  x = channel_embed if x is None else x + channel_embed
737
 
738
- new_cache = []
739
-
740
  for i, layer in enumerate(self.layers):
741
- self_cache = self_attention_cache[i]
742
- cross_cache = cross_attention_cache[i]
743
- x, new_kv_cache = layer(
744
  x, # (2, 1, D)
745
- encoder_out, # (2, S, E)
746
- src_positions=None, # CA KV is already computed
747
- tgt_positions=tgt_pos_Bx1, # (2, 1)
748
- deterministic=True,
749
- self_attn_mask=None,
750
- cross_attn_mask=cross_attn_mask,
751
  self_attn_cache=self_cache,
752
  cross_attn_cache=cross_cache,
753
  )
754
- new_cache.append(new_kv_cache)
755
 
756
  x = self.norm(x)
757
  logits_Bx1xCxV = self.logits_dense(x)
758
 
759
- return logits_Bx1xCxV.to(torch.float32), new_cache
760
 
761
  def forward(
762
- self,
763
- tgt_ids_BxTxC: torch.Tensor,
764
- encoder_out: torch.Tensor,
765
- tgt_positions: torch.Tensor,
766
- src_positions: torch.Tensor,
767
- deterministic: bool,
768
- self_attn_mask: torch.Tensor,
769
- cross_attn_mask: torch.Tensor,
770
- self_attention_cache: list[KVCache],
771
- cross_attention_cache: list[KVCache],
772
  ) -> torch.Tensor:
773
  """
774
  Forward pass for the Decoder stack, managing KV caches.
@@ -778,7 +589,6 @@ class Decoder(nn.Module):
778
  encoder_out: Output from the encoder (B, S, E).
779
  tgt_positions: Positions for target sequence (B, T).
780
  src_positions: Positions for source sequence (B, S).
781
- deterministic: Disable dropout if True.
782
  self_attn_mask: Mask for self-attention.
783
  cross_attn_mask: Mask for cross-attention.
784
  past_key_values: List containing the self-attention KV cache for each layer
@@ -804,20 +614,14 @@ class Decoder(nn.Module):
804
  channel_embed = self.embeddings[i](channel_tokens)
805
  x = channel_embed if x is None else x + channel_embed
806
 
807
- if not deterministic:
808
- x = self.dropout(x)
809
-
810
  for i, layer in enumerate(self.layers):
811
- x, _ = layer(
 
 
812
  x,
813
- encoder_out,
814
- tgt_positions=tgt_positions,
815
- src_positions=src_positions,
816
- deterministic=deterministic,
817
- self_attn_mask=self_attn_mask,
818
- cross_attn_mask=cross_attn_mask,
819
- self_attn_cache=self_attention_cache[i],
820
- cross_attn_cache=cross_attention_cache[i],
821
  prefill=True,
822
  )
823
 
@@ -831,43 +635,8 @@ class Decoder(nn.Module):
831
  class DiaModel(nn.Module):
832
  """PyTorch Dia Model using DenseGeneral."""
833
 
834
- def __init__(self, config: DiaConfig):
835
  super().__init__()
836
  self.config = config
837
- self.encoder = Encoder(config)
838
- self.decoder = Decoder(config)
839
-
840
- def forward(
841
- self,
842
- src_BxS: torch.Tensor,
843
- tgt_BxTxC: torch.Tensor,
844
- src_positions: torch.Tensor | None = None,
845
- tgt_positions: torch.Tensor | None = None,
846
- enc_self_attn_mask: torch.Tensor | None = None,
847
- dec_self_attn_mask: torch.Tensor | None = None,
848
- dec_cross_attn_mask: torch.Tensor | None = None,
849
- enable_dropout: bool = True,
850
- ):
851
- deterministic = not enable_dropout
852
-
853
- # --- Encoder Pass ---
854
- encoder_out = self.encoder(
855
- x_ids=src_BxS,
856
- src_positions=src_positions,
857
- deterministic=deterministic,
858
- attn_mask=enc_self_attn_mask,
859
- )
860
-
861
- # --- Decoder Pass ---
862
- logits, _ = self.decoder(
863
- tgt_ids_BxTxC=tgt_BxTxC,
864
- encoder_out=encoder_out,
865
- tgt_positions=tgt_positions,
866
- src_positions=src_positions,
867
- deterministic=deterministic,
868
- self_attn_mask=dec_self_attn_mask,
869
- cross_attn_mask=dec_cross_attn_mask,
870
- precomputed_cross_attn_kv=None,
871
- )
872
-
873
- return logits
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
5
  from torch.nn import RMSNorm
6
 
7
  from .config import DiaConfig
8
+ from .state import DecoderInferenceState, EncoderInferenceState, KVCache
9
 
10
 
11
  def _normalize_axes(axes: tuple[int, ...], ndim: int) -> tuple[int, ...]:
12
  return tuple(ax if ax >= 0 else ndim + ax for ax in axes)
13
 
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  class DenseGeneral(nn.Module):
16
  """
17
  PyTorch equivalent of flax.linen.DenseGeneral with shapes defined at init.
 
35
  in_shapes: tuple[int, ...],
36
  out_features: tuple[int, ...],
37
  axis: tuple[int, ...] = (-1,),
 
38
  weight_dtype: torch.dtype | None = None,
39
  device: torch.device | None = None,
40
  ):
 
42
  self.in_shapes = in_shapes
43
  self.out_features = out_features
44
  self.axis = axis
 
45
  self.kernel_shape = self.in_shapes + self.out_features
46
 
47
  factory_kwargs = {"device": device, "dtype": weight_dtype}
 
53
  kernel_contract_axes = tuple(range(len(norm_axis)))
54
 
55
  output = torch.tensordot(
56
+ inputs.to(self.weight.dtype),
57
+ self.weight,
58
  dims=(norm_axis, kernel_contract_axes),
59
  ).to(inputs.dtype)
60
  return output
61
 
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  class MlpBlock(nn.Module):
64
  """MLP block using DenseGeneral."""
65
 
66
  def __init__(
67
+ self, embed_dim: int, intermediate_dim: int, compute_dtype: torch.dtype
 
 
 
 
 
 
68
  ):
69
  super().__init__()
 
 
 
 
70
  self.dtype = compute_dtype
 
 
 
 
 
 
 
 
71
 
72
  self.wi_fused = DenseGeneral(
73
  in_shapes=(embed_dim,),
74
+ out_features=(2, intermediate_dim),
 
 
 
75
  axis=(-1,),
76
+ weight_dtype=compute_dtype,
 
77
  )
78
 
 
 
 
 
 
 
79
  self.wo = DenseGeneral(
80
  in_shapes=(intermediate_dim,),
81
  out_features=(embed_dim,),
82
  axis=(-1,),
83
+ weight_dtype=compute_dtype,
 
84
  )
85
 
86
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
87
  """Forward pass."""
 
 
 
88
  fused_x = self.wi_fused(x)
89
 
90
+ gate = fused_x[..., 0, :]
91
+ up = fused_x[..., 1, :]
 
 
 
 
92
 
93
+ hidden = torch.mul(F.silu(gate), up).to(self.dtype)
 
94
 
95
  output = self.wo(hidden)
96
  return output
 
139
  return torch.cat((first_part, second_part), dim=-1)
140
 
141
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  class Attention(nn.Module):
143
  """Attention using DenseGeneral."""
144
 
 
150
  num_query_heads: int,
151
  num_kv_heads: int,
152
  head_dim: int,
153
+ compute_dtype: torch.dtype,
154
  is_cross_attn: bool = False,
155
  out_embed_dim: int | None = None,
156
  ):
 
159
  self.num_kv_heads = num_kv_heads
160
  self.head_dim = head_dim
161
  self.is_cross_attn = is_cross_attn
 
 
 
162
  self.output_dim = out_embed_dim if out_embed_dim is not None else q_embed_dim
163
  self.projected_query_dim = num_query_heads * head_dim
164
  if num_query_heads % num_kv_heads != 0:
165
+ raise ValueError(
166
+ f"num_query_heads ({num_query_heads}) must be divisible by num_kv_heads ({num_kv_heads})"
167
+ )
168
  self.num_gqa_groups = num_query_heads // num_kv_heads
169
 
170
  # --- Projection Layers using DenseGeneral ---
 
172
  in_shapes=(q_embed_dim,),
173
  out_features=(num_query_heads, head_dim),
174
  axis=(-1,),
175
+ weight_dtype=compute_dtype,
 
176
  )
177
  self.k_proj = DenseGeneral(
178
  in_shapes=(kv_embed_dim,),
179
  out_features=(num_kv_heads, head_dim),
180
  axis=(-1,),
181
+ weight_dtype=compute_dtype,
 
182
  )
183
  self.v_proj = DenseGeneral(
184
  in_shapes=(kv_embed_dim,),
185
  out_features=(num_kv_heads, head_dim),
186
  axis=(-1,),
187
+ weight_dtype=compute_dtype,
 
188
  )
189
  self.o_proj = DenseGeneral(
190
  in_shapes=(num_query_heads, head_dim),
191
  out_features=(self.output_dim,),
192
  axis=(-2, -1),
193
+ weight_dtype=compute_dtype,
 
194
  )
195
 
196
  # --- Rotary Embedding ---
 
207
  Xkv: torch.Tensor, # (B, S, E) S = 1 in AR generation
208
  q_positions: torch.Tensor, # (B, T)
209
  kv_positions: torch.Tensor | None = None, # (B, S)
210
+ attn_mask: torch.Tensor
211
+ | None = None, # None in Decoder Self Attention, Valid mask in Others
212
  cache: KVCache | None = None, # None in Encoder, KVCache in Decoder
213
+ prefill: bool = False,
214
+ is_causal: bool = False,
215
  ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor] | None]:
216
  """
217
  Performs attention calculation with optional KV caching.
 
221
  Xkv: Key/Value source tensor (B, S, E). S=1 during single-step decoding for self-attn.
222
  q_positions: Positions for queries (B, T).
223
  kv_positions: Positions for keys/values (B, S). If None, uses q_positions.
 
224
  attn_mask: Attention mask.
225
  cache: KVCache.
226
  prefill: If True, use prefill mode.
 
238
  Xq_BxTxNxH = self.rotary_emb(Xq_BxTxNxH, position=q_positions)
239
  Xq_BxNxTxH = Xq_BxTxNxH.transpose(1, 2)
240
 
 
241
  attn_k: torch.Tensor | None = None
242
  attn_v: torch.Tensor | None = None
 
243
 
 
244
  if self.is_cross_attn:
 
245
  attn_k, attn_v = cache.k, cache.v
 
 
 
 
 
 
 
246
  else:
247
  Xk_BxSxKxH = self.k_proj(Xkv) # (B, S, K, H)
248
  Xv_BxSxKxH = self.v_proj(Xkv) # (B, S, K, H)
249
+ Xk_BxSxKxH = self.rotary_emb(
250
+ Xk_BxSxKxH, position=kv_positions
251
+ ) # (B, S, K, H)
252
 
253
  Xk_BxKxSxH = Xk_BxSxKxH.transpose(1, 2) # (B, K, S, H)
254
  Xv_BxKxSxH = Xv_BxSxKxH.transpose(1, 2) # (B, K, S, H)
 
 
 
 
 
 
 
 
255
 
 
256
  if cache is None:
257
+ attn_k = Xk_BxKxSxH
258
+ attn_v = Xv_BxKxSxH
 
259
  else:
 
260
  if prefill:
261
+ attn_k, attn_v = Xk_BxKxSxH, Xv_BxKxSxH
262
+ cache.prefill(attn_k, attn_v)
 
263
  else:
264
+ attn_k, attn_v = cache.update(Xk_BxKxSxH, Xv_BxKxSxH)
 
265
 
266
  attn_output = F.scaled_dot_product_attention(
267
  Xq_BxNxTxH,
268
  attn_k,
269
  attn_v,
270
  attn_mask=attn_mask,
 
271
  scale=1.0,
272
+ enable_gqa=self.num_gqa_groups > 1,
273
+ is_causal=is_causal,
274
  )
275
 
276
  attn_output = attn_output.transpose(1, 2).contiguous() # (B, T, N, H)
277
  output = self.o_proj(attn_output)
278
 
279
+ return output.to(original_dtype)
280
 
281
 
282
  class EncoderLayer(nn.Module):
283
  """Transformer Encoder Layer using DenseGeneral."""
284
 
285
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
286
  super().__init__()
287
  self.config = config
288
  model_config = config.model
 
295
  dtype=torch.float32,
296
  )
297
  self.self_attention = Attention(
298
+ config,
299
  q_embed_dim=embed_dim,
300
  kv_embed_dim=embed_dim,
301
  num_query_heads=enc_config.n_head,
302
  num_kv_heads=enc_config.n_head,
303
  head_dim=enc_config.head_dim,
304
+ compute_dtype=compute_dtype,
305
  is_cross_attn=False,
306
  out_embed_dim=embed_dim,
307
  )
 
311
  dtype=torch.float32,
312
  )
313
  self.mlp = MlpBlock(
 
314
  embed_dim=embed_dim,
315
  intermediate_dim=enc_config.n_hidden,
316
+ compute_dtype=compute_dtype,
 
 
317
  )
 
318
 
319
  def forward(
320
  self,
321
  x: torch.Tensor,
322
+ state: EncoderInferenceState,
 
 
323
  ) -> torch.Tensor:
324
  residual = x
325
  x_norm = self.pre_sa_norm(x)
326
+ sa_out = self.self_attention(
 
327
  Xq=x_norm,
328
  Xkv=x_norm,
329
+ q_positions=state.positions,
330
+ kv_positions=state.positions,
331
+ attn_mask=state.attn_mask,
 
332
  )
333
  x = residual + sa_out
334
 
335
  residual = x
336
  x_norm = self.post_sa_norm(x)
337
+ mlp_out = self.mlp(x_norm)
338
  x = residual + mlp_out
339
 
 
 
340
  return x
341
 
342
 
343
  class Encoder(nn.Module):
344
  """Transformer Encoder Stack using DenseGeneral."""
345
 
346
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
347
  super().__init__()
348
  self.config = config
349
  model_config = config.model
350
  enc_config = config.model.encoder
 
351
 
352
  self.embedding = nn.Embedding(
353
  model_config.src_vocab_size,
354
  enc_config.n_embd,
355
  dtype=compute_dtype,
356
  )
357
+ self.layers = nn.ModuleList(
358
+ [EncoderLayer(config, compute_dtype) for _ in range(enc_config.n_layer)]
359
+ )
360
  self.norm = RMSNorm(
361
  enc_config.n_embd,
362
  eps=model_config.normalization_layer_epsilon,
 
366
  def forward(
367
  self,
368
  x_ids: torch.Tensor,
369
+ state: EncoderInferenceState,
 
 
370
  ) -> torch.Tensor:
371
  x = self.embedding(x_ids)
372
 
 
 
 
373
  for layer in self.layers:
374
+ x = layer(x, state)
375
+
 
 
 
 
376
  x = self.norm(x)
 
 
377
  return x
378
 
379
 
380
  class DecoderLayer(nn.Module):
381
  """Transformer Decoder Layer using DenseGeneral."""
382
 
383
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
384
  super().__init__()
385
  self.config = config
386
  model_config = config.model
 
408
 
409
  # Self-Attention (GQA) with Causal Masking
410
  self.self_attention = Attention(
411
+ config,
412
  q_embed_dim=dec_embed_dim,
413
  kv_embed_dim=dec_embed_dim,
414
  num_query_heads=dec_config.gqa_query_heads,
415
  num_kv_heads=dec_config.kv_heads,
416
  head_dim=dec_config.gqa_head_dim,
417
+ compute_dtype=compute_dtype,
418
  is_cross_attn=False,
419
  out_embed_dim=dec_embed_dim,
420
  )
 
426
  num_query_heads=dec_config.cross_query_heads,
427
  num_kv_heads=dec_config.cross_query_heads,
428
  head_dim=dec_config.cross_head_dim,
429
+ compute_dtype=compute_dtype,
430
  is_cross_attn=True,
431
  out_embed_dim=dec_embed_dim,
432
  )
433
  # MLP
434
  self.mlp = MlpBlock(
 
435
  embed_dim=dec_embed_dim,
436
  intermediate_dim=dec_config.n_hidden,
437
+ compute_dtype=compute_dtype,
 
 
438
  )
439
 
440
  def forward(
441
  self,
442
  x: torch.Tensor,
443
+ state: DecoderInferenceState,
444
+ self_attn_cache: KVCache | None = None,
445
+ cross_attn_cache: KVCache | None = None,
 
 
 
 
 
446
  prefill: bool = False,
447
  ) -> torch.Tensor:
448
  residual = x
449
  x_norm = self.pre_sa_norm(x)
450
 
451
+ sa_out = self.self_attention(
452
  Xq=x_norm, # (2, 1, D)
453
  Xkv=x_norm, # (2, 1, D)
454
+ q_positions=state.dec_positions, # (2, 1)
455
+ kv_positions=state.dec_positions, # (2, 1)
456
+ attn_mask=None,
 
457
  cache=self_attn_cache,
458
  prefill=prefill,
459
+ is_causal=prefill,
460
  )
461
 
462
  x = residual + sa_out
463
 
 
464
  residual = x
465
  x_norm = self.pre_ca_norm(x)
466
+ ca_out = self.cross_attention(
467
  Xq=x_norm,
468
+ Xkv=state.enc_out,
469
+ q_positions=state.dec_positions,
470
+ kv_positions=state.enc_positions,
471
+ attn_mask=state.dec_cross_attn_mask,
 
472
  cache=cross_attn_cache,
473
  )
474
  x = residual + ca_out
475
 
 
476
  residual = x
477
  x_norm = self.pre_mlp_norm(x)
478
+ mlp_out = self.mlp(x_norm)
479
  x = residual + mlp_out
480
 
481
+ return x
482
 
483
 
484
  class Decoder(nn.Module):
485
  """Transformer Decoder Stack using DenseGeneral."""
486
 
487
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
488
  super().__init__()
489
  self.config = config
490
  model_config = config.model
491
  dec_config = config.model.decoder
 
492
  data_config = config.data
 
 
493
  self.num_channels = data_config.channels
494
  self.num_layers = dec_config.n_layer
495
 
496
  self.embeddings = nn.ModuleList(
497
  [
498
+ nn.Embedding(
499
+ model_config.tgt_vocab_size, dec_config.n_embd, dtype=compute_dtype
500
+ )
501
  for _ in range(self.num_channels)
502
  ]
503
  )
504
+ self.layers = nn.ModuleList(
505
+ [
506
+ DecoderLayer(config=config, compute_dtype=compute_dtype)
507
+ for _ in range(self.num_layers)
508
+ ]
509
+ )
510
+
511
  self.norm = RMSNorm(
512
  dec_config.n_embd,
513
  eps=model_config.normalization_layer_epsilon,
514
  dtype=torch.float32,
515
  )
516
 
 
517
  self.logits_dense = DenseGeneral(
518
  in_shapes=(dec_config.n_embd,),
519
  out_features=(self.num_channels, model_config.tgt_vocab_size),
520
  axis=(-1,),
521
+ weight_dtype=compute_dtype,
 
522
  )
 
523
 
524
+ def precompute_cross_attn_cache(
525
  self,
526
+ enc_out: torch.Tensor, # (B, S, E)
527
+ enc_positions: torch.Tensor, # (B, S)
 
528
  ) -> list[KVCache]:
529
  """
530
  Computes the Key and Value tensors for cross-attention for each layer from the encoder output.
 
533
 
534
  for layer in self.layers:
535
  cross_attn_module = layer.cross_attention
536
+ k_proj = cross_attn_module.k_proj(enc_out)
537
+ v_proj = cross_attn_module.v_proj(enc_out)
538
 
539
+ k_proj = cross_attn_module.rotary_emb(k_proj, position=enc_positions)
540
  k = k_proj.transpose(1, 2)
541
  v = v_proj.transpose(1, 2)
542
 
543
+ per_layer_kv_cache.append(KVCache.from_kv(k, v))
 
 
 
 
 
 
 
 
 
544
 
545
  return per_layer_kv_cache
546
 
547
  def decode_step(
548
  self,
549
  tgt_ids_Bx1xC: torch.Tensor, # [B, 1, C]
550
+ state: DecoderInferenceState,
 
 
 
 
 
551
  ) -> torch.Tensor:
552
  """
553
  Performs a single decoding step, managing KV caches layer by layer.
 
556
  A tuple containing:
557
  - logits_Bx1xCV: The final output logits for the current step (B, 1, C*V), cast to float32.
558
  """
 
559
 
560
  x = None
561
  for i in range(self.num_channels):
 
563
  channel_embed = self.embeddings[i](channel_tokens)
564
  x = channel_embed if x is None else x + channel_embed
565
 
 
 
566
  for i, layer in enumerate(self.layers):
567
+ self_cache = state.self_attn_cache[i]
568
+ cross_cache = state.cross_attn_cache[i]
569
+ x = layer(
570
  x, # (2, 1, D)
571
+ state,
 
 
 
 
 
572
  self_attn_cache=self_cache,
573
  cross_attn_cache=cross_cache,
574
  )
 
575
 
576
  x = self.norm(x)
577
  logits_Bx1xCxV = self.logits_dense(x)
578
 
579
+ return logits_Bx1xCxV.to(torch.float32)
580
 
581
  def forward(
582
+ self, tgt_ids_BxTxC: torch.Tensor, state: DecoderInferenceState
 
 
 
 
 
 
 
 
 
583
  ) -> torch.Tensor:
584
  """
585
  Forward pass for the Decoder stack, managing KV caches.
 
589
  encoder_out: Output from the encoder (B, S, E).
590
  tgt_positions: Positions for target sequence (B, T).
591
  src_positions: Positions for source sequence (B, S).
 
592
  self_attn_mask: Mask for self-attention.
593
  cross_attn_mask: Mask for cross-attention.
594
  past_key_values: List containing the self-attention KV cache for each layer
 
614
  channel_embed = self.embeddings[i](channel_tokens)
615
  x = channel_embed if x is None else x + channel_embed
616
 
 
 
 
617
  for i, layer in enumerate(self.layers):
618
+ self_cache = state.self_attn_cache[i]
619
+ cross_cache = state.cross_attn_cache[i]
620
+ x = layer(
621
  x,
622
+ state,
623
+ self_attn_cache=self_cache,
624
+ cross_attn_cache=cross_cache,
 
 
 
 
 
625
  prefill=True,
626
  )
627
 
 
635
  class DiaModel(nn.Module):
636
  """PyTorch Dia Model using DenseGeneral."""
637
 
638
+ def __init__(self, config: DiaConfig, compute_dtype: torch.dtype):
639
  super().__init__()
640
  self.config = config
641
+ self.encoder = Encoder(config, compute_dtype)
642
+ self.decoder = Decoder(config, compute_dtype)
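(Not part of the diff) A minimal sketch of how the refactored stack is constructed now that an explicit compute dtype is threaded through every module, assuming `config` is an already-loaded DiaConfig:

import torch
from dia.layers import DiaModel

model = DiaModel(config, compute_dtype=torch.bfloat16)  # embeddings/projections in bfloat16, norms stay float32
model = model.to("cuda").eval()                         # normally handled by Dia.from_local / from_pretrained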
dia/model.py CHANGED
@@ -1,26 +1,46 @@
 
 
 
1
  import dac
2
  import numpy as np
3
  import torch
4
  import torchaudio
5
  from huggingface_hub import hf_hub_download
6
 
7
- from .audio import audio_to_codebook, codebook_to_audio
 
 
 
 
 
 
8
  from .config import DiaConfig
9
- from .layers import DiaModel, KVCache
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
 
12
  def _sample_next_token(
13
  logits_BCxV: torch.Tensor,
14
  temperature: float,
15
  top_p: float,
16
- use_cfg_filter: bool,
17
  cfg_filter_top_k: int | None = None,
18
  ) -> torch.Tensor:
19
  if temperature == 0.0:
20
  return torch.argmax(logits_BCxV, dim=-1)
21
 
22
  logits_BCxV = logits_BCxV / temperature
23
- if use_cfg_filter and cfg_filter_top_k is not None:
24
  _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
25
  mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
26
  mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
@@ -28,17 +48,21 @@ def _sample_next_token(
28
 
29
  if top_p < 1.0:
30
  probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
31
- sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
 
 
32
  cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
33
 
34
- # Calculate indices to remove based on top_p
35
  sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
36
- # Shift the mask to the right to keep the first token above the threshold
37
- sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[..., :-1].clone()
38
- sorted_indices_to_remove_BCxV[..., 0] = 0 # Always keep the most probable token
 
39
 
40
  indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
41
- indices_to_remove_BCxV.scatter_(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV)
 
 
42
  logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
43
 
44
  final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
@@ -48,31 +72,61 @@ def _sample_next_token(
48
  return sampled_indices_C
49
 
50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  class Dia:
52
- def __init__(self, config: DiaConfig, device: torch.device = torch.device("cuda")):
 
 
 
 
 
53
  """Initializes the Dia model.
54
 
55
  Args:
56
  config: The configuration object for the model.
57
- device: The device to load the model onto.
58
 
59
  Raises:
60
  RuntimeError: If there is an error loading the DAC model.
61
  """
62
  super().__init__()
63
  self.config = config
64
- self.device = device
65
- self.model = DiaModel(config)
 
 
 
66
  self.dac_model = None
67
 
68
  @classmethod
69
- def from_local(cls, config_path: str, checkpoint_path: str, device: torch.device = torch.device("cuda")) -> "Dia":
 
 
 
 
 
 
70
  """Loads the Dia model from local configuration and checkpoint files.
71
 
72
  Args:
73
  config_path: Path to the configuration JSON file.
74
  checkpoint_path: Path to the model checkpoint (.pth) file.
75
- device: The device to load the model onto.
76
 
77
  Returns:
78
  An instance of the Dia model loaded with weights and set to eval mode.
@@ -85,23 +139,29 @@ class Dia:
85
  if config is None:
86
  raise FileNotFoundError(f"Config file not found at {config_path}")
87
 
88
- dia = cls(config, device)
89
 
90
  try:
91
- dia.model.load_state_dict(torch.load(checkpoint_path, map_location=device))
 
92
  except FileNotFoundError:
93
  raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
94
  except Exception as e:
95
- raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e
 
 
96
 
97
- dia.model.to(device)
98
  dia.model.eval()
99
  dia._load_dac_model()
100
  return dia
101
 
102
  @classmethod
103
  def from_pretrained(
104
- cls, model_name: str = "nari-labs/Dia-1.6B", device: torch.device = torch.device("cuda")
 
 
 
105
  ) -> "Dia":
106
  """Loads the Dia model from a Hugging Face Hub repository.
107
 
@@ -110,7 +170,7 @@ class Dia:
110
 
111
  Args:
112
  model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
113
- device: The device to load the model onto.
114
 
115
  Returns:
116
  An instance of the Dia model loaded with weights and set to eval mode.
@@ -121,7 +181,7 @@ class Dia:
121
  """
122
  config_path = hf_hub_download(repo_id=model_name, filename="config.json")
123
  checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
124
- return cls.from_local(config_path, checkpoint_path, device)
125
 
126
  def _load_dac_model(self):
127
  try:
@@ -131,44 +191,7 @@ class Dia:
131
  raise RuntimeError("Failed to load DAC model") from e
132
  self.dac_model = dac_model
133
 
134
- def _create_attn_mask(
135
- self,
136
- q_padding_mask_1d: torch.Tensor,
137
- k_padding_mask_1d: torch.Tensor,
138
- is_causal: bool = False,
139
- ) -> torch.Tensor:
140
- """
141
- Creates the attention mask (self or cross) mimicking JAX segment ID logic.
142
- """
143
- B1, Tq = q_padding_mask_1d.shape
144
- B2, Tk = k_padding_mask_1d.shape
145
- assert B1 == B2, "Query and key batch dimensions must match"
146
-
147
- p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
148
- p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
149
-
150
- # Condition A: Non-padding query attends to non-padding key
151
- non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
152
-
153
- # Condition B: Padding query attends to padding key
154
- pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
155
-
156
- # Combine: True if padding status is compatible (both non-pad OR both pad)
157
- # This implementation follows Jax TPU splash attention kernel
158
- mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
159
-
160
- if is_causal:
161
- # Ensure causality for self-attention (Tq == Tk)
162
- assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
163
- # Standard lower-triangular causal mask (True means allow)
164
- causal_mask_2d = torch.tril(torch.ones((Tq, Tk), dtype=torch.bool, device=self.device)) # Shape [Tq, Tk]
165
- causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
166
- return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk] for broadcasting across heads
167
- else:
168
- # For cross-attention or non-causal self-attention
169
- return mask.unsqueeze(1) # Shape [B, 1, Tq, Tk] for broadcasting across heads
170
-
171
- def _prepare_text_input(self, text: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
172
  """Encodes text prompt, pads, and creates attention mask and positions."""
173
  text_pad_value = self.config.data.text_pad_value
174
  max_len = self.config.data.text_length
@@ -190,14 +213,168 @@ class Dia:
190
  constant_values=text_pad_value,
191
  ).astype(np.uint8)
192
 
193
- src_tokens = torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0) # [1, S]
194
- src_positions = torch.arange(max_len, device=self.device).to(torch.long).unsqueeze(0) # [1, S]
 
 
195
 
196
- src_padding_mask = (src_tokens != text_pad_value).to(self.device) # [1, S]
 
 
 
 
 
 
 
197
 
198
- enc_self_attn_mask = self._create_attn_mask(src_padding_mask, src_padding_mask, is_causal=False) # [1, S, S]
 
 
 
 
 
199
 
200
- return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask
 
201
 
202
  @torch.inference_mode()
203
  def generate(
@@ -207,225 +384,105 @@ class Dia:
207
  cfg_scale: float = 3.0,
208
  temperature: float = 1.3,
209
  top_p: float = 0.95,
210
- use_cfg_filter: bool = True,
211
- use_torch_compile: bool = True,
212
- cfg_filter_top_k: int = 100,
213
  audio_prompt_path: str | None = None,
 
 
214
  ) -> np.ndarray:
215
- """
216
- Generates audio from a text prompt (and optional audio prompt) using the Nari model.
217
-
218
- Returns:
219
- A tensor of generated audio codes (shape: [max_tokens, num_channels]).
220
- """
221
- num_channels = self.config.data.channels
222
- audio_bos_value = self.config.data.audio_bos_value
223
  audio_eos_value = self.config.data.audio_eos_value
224
  audio_pad_value = self.config.data.audio_pad_value
225
  delay_pattern = self.config.data.delay_pattern
226
  max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
227
- delay_tensor = torch.tensor(delay_pattern, dtype=torch.long, device=self.device)
228
  max_delay_pattern = max(delay_pattern)
229
  self.model.eval()
230
 
231
- (
232
- cond_src_BxS,
233
- cond_src_positions_BxS,
234
- cond_src_padding_mask_BxS,
235
- cond_enc_self_attn_mask_Bx1xSxS,
236
- ) = self._prepare_text_input(text)
237
-
238
- unc_src_BxS = torch.zeros_like(cond_src_BxS)
239
- src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
240
- src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
241
- src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1)
242
- enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(2, -1, -1, -1)
243
-
244
- # 2. Encoder Pass
245
- # with torch.autocast(device_type="cuda", dtype=forward_dtype):
246
- encoder_out = self.model.encoder(
247
- x_ids=src_BxS,
248
- src_positions=src_positions_BxS,
249
- deterministic=True,
250
- attn_mask=enc_self_attn_mask_Bx1xSxS,
251
- ) # Shape: (B, S, E)
252
-
253
- # 3. Prepare Decoder Inputs
254
- # 3-1. Allocate KV Cache (Static)
255
- decoder_cross_attention_cache: list[KVCache] = self.model.decoder.precompute_cross_attention_kv(
256
- max_tokens, encoder_out, src_positions_BxS
257
- )
258
-
259
- decoder_self_attention_cache: list[KVCache] = []
260
- for _ in range(self.model.decoder.num_layers):
261
- decoder_self_attention_cache.append(
262
- KVCache(
263
- self.config.model.decoder.gqa_query_heads,
264
- max_tokens,
265
- self.config.model.decoder.gqa_head_dim,
266
- self.device,
267
- )
268
- )
269
-
270
- # 3-2. Initialize Decoder Inputs
271
- generated_BxTxC = torch.full(
272
- (2, 1, num_channels),
273
- fill_value=audio_bos_value,
274
- dtype=torch.long,
275
- device=self.device,
276
- )
277
-
278
- current_step = 0
279
- prompt_len_inc_bos = 1 # Start with BOS length
280
-
281
- # 3-3. Load Audio Prompt (if provided)
282
- if audio_prompt_path is not None:
283
- audio_prompt, sr = torchaudio.load(audio_prompt_path, channels_first=True) # C, T
284
- if sr != 44100: # Resample to 44.1kHz
285
- audio_prompt = torchaudio.functional.resample(audio_prompt, sr, 44100)
286
- audio_prompt = audio_prompt.to(self.device).unsqueeze(0) # 1, C, T
287
- audio_prompt = audio_to_codebook(self.dac_model, audio_prompt, data_config=self.config.data)
288
- generated_BxTxC = torch.cat([generated_BxTxC, audio_prompt.expand(2, -1, -1)], dim=1)
289
-
290
- prefill_len = generated_BxTxC.shape[1]
291
- prompt_len_inc_bos = prefill_len
292
- prefill_tgt_pos = torch.arange(prefill_len, device=self.device).unsqueeze(0).expand(2, -1)
293
- prefill_tgt_padding_mask = (generated_BxTxC != audio_pad_value).any(dim=2)
294
-
295
- prefill_self_attn_mask = self._create_attn_mask(
296
- prefill_tgt_padding_mask,
297
- prefill_tgt_padding_mask,
298
- is_causal=True,
299
- )
300
- prefill_cross_attn_mask = self._create_attn_mask(
301
- prefill_tgt_padding_mask,
302
- src_padding_mask_BxS,
303
- is_causal=False,
304
- )
305
 
306
- _ = self.model.decoder.forward(
307
- tgt_ids_BxTxC=generated_BxTxC,
308
- encoder_out=encoder_out,
309
- tgt_positions=prefill_tgt_pos,
310
- src_positions=src_positions_BxS,
311
- deterministic=True,
312
- self_attn_mask=prefill_self_attn_mask,
313
- cross_attn_mask=prefill_cross_attn_mask,
314
- self_attention_cache=decoder_self_attention_cache,
315
- cross_attention_cache=decoder_cross_attention_cache,
316
- )
317
 
318
- current_step = prefill_len - 1
 
319
 
320
- # 4. Autoregressive Generation Loop
321
- eos_detected_channel_0 = False
322
  eos_countdown = -1
323
- extra_steps_after_eos = 30
324
- # Make generated_BxTxC a fixed size tensor
325
- # Length is either 1 + max tokens or 1 + prompt len + max tokens
326
- generated_BxTxC = torch.cat(
327
- [
328
- generated_BxTxC,
329
- torch.full(
330
- (2, max_tokens, num_channels),
331
- fill_value=-1,
332
- dtype=torch.long,
333
- device=self.device,
334
- ),
335
- ],
336
- dim=1,
337
- )
338
 
339
- decode_step = self.model.decoder.decode_step
340
  if use_torch_compile:
341
- decode_step = torch.compile(
342
- self.model.decoder.decode_step,
343
- mode="default",
344
- )
345
 
346
- tgt_padding_mask = (
347
- (generated_BxTxC[:, -1, :].unsqueeze(1) != audio_pad_value).any(dim=2).to(self.device)
348
- ) # [B, 1]
349
- # Generated tokens are never PAD, so we use fixed mask
350
- decoder_cross_attn_mask = self._create_attn_mask(
351
- tgt_padding_mask, # Query mask [B, 1]
352
- src_padding_mask_BxS, # Key mask [B, S]
353
- is_causal=False,
354
- ) # [B, 1, 1, S]
355
-
356
- for step in range(current_step, current_step + max_tokens):
357
- tgt_ids_Bx1xC = generated_BxTxC[:, step, :].unsqueeze(1)
358
- tgt_pos_Bx1 = torch.full(
359
- (2, 1),
360
- fill_value=step,
361
- dtype=torch.long,
362
- device=self.device,
363
- )
364
 
365
- logits_Bx1xCxV, new_cache = decode_step(
366
- tgt_ids_Bx1xC=tgt_ids_Bx1xC,
367
- tgt_pos_Bx1=tgt_pos_Bx1,
368
- encoder_out=encoder_out,
369
- self_attn_mask=None,
370
- cross_attn_mask=decoder_cross_attn_mask,
371
- self_attention_cache=decoder_self_attention_cache,
372
- cross_attention_cache=decoder_cross_attention_cache,
373
  )
374
-
375
- for i, layer_cache in enumerate(decoder_self_attention_cache):
376
- layer_cache.update_cache(new_cache[i][0], new_cache[i][1])
377
-
378
- V = self.config.model.tgt_vocab_size
379
- logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :] # B, C, V
380
- uncond_logits_CxV = logits_last_BxCxV[0, :, :]
381
- cond_logits_CxV = logits_last_BxCxV[1, :, :]
382
-
383
- cfg_logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
384
-
385
- logits_CxV = cfg_logits_CxV.reshape((-1, V)) # C, V
386
- logits_CxV[:, 1025:] = -torch.inf
387
-
388
- # Sample next token
389
- pred_C = _sample_next_token(
390
- logits_CxV.float(),
391
- temperature=temperature,
392
- top_p=top_p,
393
- use_cfg_filter=use_cfg_filter,
394
- cfg_filter_top_k=cfg_filter_top_k,
395
  )
396
 
397
- generation_step_index = step - current_step
398
- if audio_prompt_path is None:
399
- pred_C = torch.where(
400
- generation_step_index >= delay_tensor,
401
- pred_C,
402
- audio_bos_value,
403
- )
404
-
405
- generated_BxTxC[:, step + 1, :] = pred_C.unsqueeze(0).expand(2, -1)
406
-
407
- if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
408
- eos_detected_channel_0 = True
409
- eos_countdown = extra_steps_after_eos
410
 
411
  if eos_countdown > 0:
412
  step_after_eos = max_delay_pattern - eos_countdown
413
  for i, d in enumerate(delay_pattern):
414
  if step_after_eos == d:
415
- generated_BxTxC[:, step + 1, i] = audio_eos_value
416
  elif step_after_eos > d:
417
- generated_BxTxC[:, step + 1, i] = audio_pad_value
418
  eos_countdown -= 1
419
- if eos_countdown == 0:
420
- break
421
 
422
- generation_step_index = step - current_step + 1
 
423
 
424
- output_codes = generated_BxTxC[:, prompt_len_inc_bos : step + 1, :]
 
425
 
426
- generated_codes = output_codes[0]
 
 
 
 
 
 
427
 
428
- audio = codebook_to_audio(
429
- generated_codes.transpose(1, 0), self.dac_model, delay_pattern, B=1, T=max_tokens, C=num_channels
430
- )
431
- return audio.squeeze().cpu().numpy()
1
+ import time
2
+ from enum import Enum
3
+
4
  import dac
5
  import numpy as np
6
  import torch
7
  import torchaudio
8
  from huggingface_hub import hf_hub_download
9
 
10
+ from .audio import (
11
+ apply_audio_delay,
12
+ build_delay_indices,
13
+ build_revert_indices,
14
+ decode,
15
+ revert_audio_delay,
16
+ )
17
  from .config import DiaConfig
18
+ from .layers import DiaModel
19
+ from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState
20
+
21
+
22
+ DEFAULT_SAMPLE_RATE = 44100
23
+
24
+
25
+ def _get_default_device():
26
+ if torch.cuda.is_available():
27
+ return torch.device("cuda")
28
+ elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
29
+ return torch.device("mps")
30
+ return torch.device("cpu")
31
 
32
 
33
  def _sample_next_token(
34
  logits_BCxV: torch.Tensor,
35
  temperature: float,
36
  top_p: float,
 
37
  cfg_filter_top_k: int | None = None,
38
  ) -> torch.Tensor:
39
  if temperature == 0.0:
40
  return torch.argmax(logits_BCxV, dim=-1)
41
 
42
  logits_BCxV = logits_BCxV / temperature
43
+ if cfg_filter_top_k is not None:
44
  _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
45
  mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
46
  mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
 
48
 
49
  if top_p < 1.0:
50
  probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
51
+ sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
52
+ probs_BCxV, dim=-1, descending=True
53
+ )
54
  cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)
55
 
 
56
  sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
57
+ sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
58
+ ..., :-1
59
+ ].clone()
60
+ sorted_indices_to_remove_BCxV[..., 0] = 0
61
 
62
  indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
63
+ indices_to_remove_BCxV.scatter_(
64
+ dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
65
+ )
66
  logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)
67
 
68
  final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
 
72
  return sampled_indices_C
73
 
74
 
75
+ class ComputeDtype(str, Enum):
76
+ FLOAT32 = "float32"
77
+ FLOAT16 = "float16"
78
+ BFLOAT16 = "bfloat16"
79
+
80
+ def to_dtype(self) -> torch.dtype:
81
+ if self == ComputeDtype.FLOAT32:
82
+ return torch.float32
83
+ elif self == ComputeDtype.FLOAT16:
84
+ return torch.float16
85
+ elif self == ComputeDtype.BFLOAT16:
86
+ return torch.bfloat16
87
+ else:
88
+ raise ValueError(f"Unsupported compute dtype: {self}")
89
+
90
+
91
  class Dia:
92
+ def __init__(
93
+ self,
94
+ config: DiaConfig,
95
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
96
+ device: torch.device | None = None,
97
+ ):
98
  """Initializes the Dia model.
99
 
100
  Args:
101
  config: The configuration object for the model.
102
+ device: The device to load the model onto. If None, will automatically select the best available device.
103
 
104
  Raises:
105
  RuntimeError: If there is an error loading the DAC model.
106
  """
107
  super().__init__()
108
  self.config = config
109
+ self.device = device if device is not None else _get_default_device()
110
+ if isinstance(compute_dtype, str):
111
+ compute_dtype = ComputeDtype(compute_dtype)
112
+ self.compute_dtype = compute_dtype.to_dtype()
113
+ self.model = DiaModel(config, self.compute_dtype)
114
  self.dac_model = None
115
 
116
  @classmethod
117
+ def from_local(
118
+ cls,
119
+ config_path: str,
120
+ checkpoint_path: str,
121
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
122
+ device: torch.device | None = None,
123
+ ) -> "Dia":
124
  """Loads the Dia model from local configuration and checkpoint files.
125
 
126
  Args:
127
  config_path: Path to the configuration JSON file.
128
  checkpoint_path: Path to the model checkpoint (.pth) file.
129
+ device: The device to load the model onto. If None, will automatically select the best available device.
130
 
131
  Returns:
132
  An instance of the Dia model loaded with weights and set to eval mode.
 
139
  if config is None:
140
  raise FileNotFoundError(f"Config file not found at {config_path}")
141
 
142
+ dia = cls(config, compute_dtype, device)
143
 
144
  try:
145
+ state_dict = torch.load(checkpoint_path, map_location=dia.device)
146
+ dia.model.load_state_dict(state_dict)
147
  except FileNotFoundError:
148
  raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
149
  except Exception as e:
150
+ raise RuntimeError(
151
+ f"Error loading checkpoint from {checkpoint_path}"
152
+ ) from e
153
 
154
+ dia.model.to(dia.device)
155
  dia.model.eval()
156
  dia._load_dac_model()
157
  return dia
158
 
159
  @classmethod
160
  def from_pretrained(
161
+ cls,
162
+ model_name: str = "nari-labs/Dia-1.6B",
163
+ compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
164
+ device: torch.device | None = None,
165
  ) -> "Dia":
166
  """Loads the Dia model from a Hugging Face Hub repository.
167
 
 
170
 
171
  Args:
172
  model_name: The Hugging Face Hub repository ID (e.g., "NariLabs/Dia-1.6B").
173
+ device: The device to load the model onto. If None, will automatically select the best available device.
174
 
175
  Returns:
176
  An instance of the Dia model loaded with weights and set to eval mode.
 
181
  """
182
  config_path = hf_hub_download(repo_id=model_name, filename="config.json")
183
  checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
184
+ return cls.from_local(config_path, checkpoint_path, compute_dtype, device)
185
 
186
  def _load_dac_model(self):
187
  try:
 
191
  raise RuntimeError("Failed to load DAC model") from e
192
  self.dac_model = dac_model
193
 
194
+ def _prepare_text_input(self, text: str) -> torch.Tensor:
195
  """Encodes text prompt, pads, and creates attention mask and positions."""
196
  text_pad_value = self.config.data.text_pad_value
197
  max_len = self.config.data.text_length
 
213
  constant_values=text_pad_value,
214
  ).astype(np.uint8)
215
 
216
+ src_tokens = (
217
+ torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)
218
+ ) # [1, S]
219
+ return src_tokens
220
 
221
+ def _prepare_audio_prompt(
222
+ self, audio_prompt: torch.Tensor | None
223
+ ) -> tuple[torch.Tensor, int]:
224
+ num_channels = self.config.data.channels
225
+ audio_bos_value = self.config.data.audio_bos_value
226
+ audio_pad_value = self.config.data.audio_pad_value
227
+ delay_pattern = self.config.data.delay_pattern
228
+ max_delay_pattern = max(delay_pattern)
229
 
230
+ prefill = torch.full(
231
+ (1, num_channels),
232
+ fill_value=audio_bos_value,
233
+ dtype=torch.int,
234
+ device=self.device,
235
+ )
236
 
237
+ prefill_step = 1
238
+
239
+ if audio_prompt is not None:
240
+ prefill_step += audio_prompt.shape[0]
241
+ prefill = torch.cat([prefill, audio_prompt], dim=0)
242
+
243
+ delay_pad_tensor = torch.full(
244
+ (max_delay_pattern, num_channels),
245
+ fill_value=-1,
246
+ dtype=torch.int,
247
+ device=self.device,
248
+ )
249
+ prefill = torch.cat([prefill, delay_pad_tensor], dim=0)
250
+
251
+ delay_precomp = build_delay_indices(
252
+ B=1,
253
+ T=prefill.shape[0],
254
+ C=num_channels,
255
+ delay_pattern=delay_pattern,
256
+ )
257
+
258
+ prefill = apply_audio_delay(
259
+ audio_BxTxC=prefill.unsqueeze(0),
260
+ pad_value=audio_pad_value,
261
+ bos_value=audio_bos_value,
262
+ precomp=delay_precomp,
263
+ ).squeeze(0)
264
+
265
+ return prefill, prefill_step
266
+
267
+ def _prepare_generation(
268
+ self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool
269
+ ):
270
+ enc_input_cond = self._prepare_text_input(text)
271
+ enc_input_uncond = torch.zeros_like(enc_input_cond)
272
+ enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)
273
+
274
+ if isinstance(audio_prompt, str):
275
+ audio_prompt = self.load_audio(audio_prompt)
276
+ prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)
277
+
278
+ if verbose:
279
+ print("generate: data loaded")
280
+
281
+ enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
282
+ encoder_out = self.model.encoder(enc_input, enc_state)
283
+
284
+ dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
285
+ encoder_out, enc_state.positions
286
+ )
287
+ dec_state = DecoderInferenceState.new(
288
+ self.config,
289
+ enc_state,
290
+ encoder_out,
291
+ dec_cross_attn_cache,
292
+ self.compute_dtype,
293
+ )
294
+ dec_output = DecoderOutput.new(self.config, self.device)
295
+ dec_output.prefill(prefill, prefill_step)
296
+
297
+ dec_step = prefill_step - 1
298
+ if dec_step > 0:
299
+ dec_state.prepare_step(0, dec_step)
300
+ tokens_BxTxC = (
301
+ dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
302
+ )
303
+ self.model.decoder.forward(tokens_BxTxC, dec_state)
304
+
305
+ return dec_state, dec_output
306
+
307
+ def _decoder_step(
308
+ self,
309
+ tokens_Bx1xC: torch.Tensor,
310
+ dec_state: DecoderInferenceState,
311
+ cfg_scale: float,
312
+ temperature: float,
313
+ top_p: float,
314
+ cfg_filter_top_k: int,
315
+ ) -> torch.Tensor:
316
+ audio_eos_value = self.config.data.audio_eos_value
317
+ logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)
318
+
319
+ logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
320
+ uncond_logits_CxV = logits_last_BxCxV[0, :, :]
321
+ cond_logits_CxV = logits_last_BxCxV[1, :, :]
322
+
323
+ logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
324
+ logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
325
+ logits_CxV[1:, audio_eos_value:] = -torch.inf
326
+
327
+ pred_C = _sample_next_token(
328
+ logits_CxV.float(),
329
+ temperature=temperature,
330
+ top_p=top_p,
331
+ cfg_filter_top_k=cfg_filter_top_k,
332
+ )
333
+ return pred_C
334
+
335
+ def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
336
+ num_channels = self.config.data.channels
337
+ seq_length = generated_codes.shape[0]
338
+ delay_pattern = self.config.data.delay_pattern
339
+ audio_pad_value = self.config.data.audio_pad_value
340
+ max_delay_pattern = max(delay_pattern)
341
+
342
+ revert_precomp = build_revert_indices(
343
+ B=1,
344
+ T=seq_length,
345
+ C=num_channels,
346
+ delay_pattern=delay_pattern,
347
+ )
348
+
349
+ codebook = revert_audio_delay(
350
+ audio_BxTxC=generated_codes.unsqueeze(0),
351
+ pad_value=audio_pad_value,
352
+ precomp=revert_precomp,
353
+ T=seq_length,
354
+ )[:, :-max_delay_pattern, :]
355
+
356
+ min_valid_index = 0
357
+ max_valid_index = 1023
358
+ invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
359
+ codebook[invalid_mask] = 0
360
+
361
+ audio = decode(self.dac_model, codebook.transpose(1, 2))
362
+
363
+ return audio.squeeze().cpu().numpy()
364
+
365
+ def load_audio(self, audio_path: str) -> torch.Tensor:
366
+ audio, sr = torchaudio.load(audio_path, channels_first=True) # C, T
367
+ if sr != DEFAULT_SAMPLE_RATE:
368
+ audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
369
+ audio = audio.to(self.device).unsqueeze(0) # 1, C, T
370
+ audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
371
+ _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data) # 1, C, T
372
+ return encoded_frame.squeeze(0).transpose(0, 1)
373
+
374
+ def save_audio(self, path: str, audio: np.ndarray):
375
+ import soundfile as sf
376
+
377
+ sf.write(path, audio, DEFAULT_SAMPLE_RATE)
378
 
379
  @torch.inference_mode()
380
  def generate(
 
384
  cfg_scale: float = 3.0,
385
  temperature: float = 1.3,
386
  top_p: float = 0.95,
387
+ use_torch_compile: bool = False,
388
+ cfg_filter_top_k: int = 35,
389
+ audio_prompt: str | torch.Tensor | None = None,
390
  audio_prompt_path: str | None = None,
391
+ use_cfg_filter: bool | None = None,
392
+ verbose: bool = False,
393
  ) -> np.ndarray:
 
 
 
 
 
 
 
 
394
  audio_eos_value = self.config.data.audio_eos_value
395
  audio_pad_value = self.config.data.audio_pad_value
396
  delay_pattern = self.config.data.delay_pattern
397
  max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
 
398
  max_delay_pattern = max(delay_pattern)
399
  self.model.eval()
400
 
401
+ if audio_prompt_path:
402
+ print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
403
+ audio_prompt = audio_prompt_path
404
+ if use_cfg_filter is not None:
405
+ print("Warning: use_cfg_filter is deprecated.")
 
406
 
407
+ if verbose:
408
+ total_start_time = time.time()
 
 
 
 
 
 
 
 
 
409
 
410
+ dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
411
+ dec_step = dec_output.prefill_step - 1
412
 
413
+ bos_countdown = max_delay_pattern
414
+ eos_detected = False
415
  eos_countdown = -1
416
 
 
417
  if use_torch_compile:
418
+ step_fn = torch.compile(self._decoder_step, mode="default")
419
+ else:
420
+ step_fn = self._decoder_step
 
421
 
422
+ if verbose:
423
+ print("generate: starting generation loop")
424
+ if use_torch_compile:
425
+ print(
426
+ "generate: by using use_torch_compile=True, the first step would take long"
427
+ )
428
+ start_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
429
 
430
+ while dec_step < max_tokens:
431
+ dec_state.prepare_step(dec_step)
432
+ tokens_Bx1xC = (
433
+ dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
 
 
 
 
434
  )
435
+ pred_C = step_fn(
436
+ tokens_Bx1xC,
437
+ dec_state,
438
+ cfg_scale,
439
+ temperature,
440
+ top_p,
441
+ cfg_filter_top_k,
442
  )
443
 
444
+ if (
445
+ not eos_detected and pred_C[0] == audio_eos_value
446
+ ) or dec_step == max_tokens - max_delay_pattern - 1:
447
+ eos_detected = True
448
+ eos_countdown = max_delay_pattern
 
 
 
 
 
 
 
 
449
 
450
  if eos_countdown > 0:
451
  step_after_eos = max_delay_pattern - eos_countdown
452
  for i, d in enumerate(delay_pattern):
453
  if step_after_eos == d:
454
+ pred_C[i] = audio_eos_value
455
  elif step_after_eos > d:
456
+ pred_C[i] = audio_pad_value
457
  eos_countdown -= 1
 
 
458
 
459
+ bos_countdown = max(0, bos_countdown - 1)
460
+ dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)
461
 
462
+ if eos_countdown == 0:
463
+ break
464
 
465
+ dec_step += 1
466
+ if verbose and dec_step % 86 == 0:
467
+ duration = time.time() - start_time
468
+ print(
469
+ f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
470
+ )
471
+ start_time = time.time()
472
 
473
+ if dec_output.prefill_step >= dec_step + 1:
474
+ print("Warning: Nothing generated")
475
+ return None
476
+
477
+ generated_codes = dec_output.generated_tokens[
478
+ dec_output.prefill_step : dec_step + 1, :
479
+ ]
480
+
481
+ if verbose:
482
+ total_step = dec_step + 1 - dec_output.prefill_step
483
+ total_duration = time.time() - total_start_time
484
+ print(
485
+ f"generate: total step={total_step}, total duration={total_duration:.3f}s"
486
+ )
487
+
488
+ return self._generate_output(generated_codes)
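(Not part of the diff) A minimal end-to-end sketch of the refactored inference API; the prompt text and file names below are illustrative:

from dia.model import Dia

model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="bfloat16")
audio = model.generate(
    "[S1] Hello there.",      # speaker-tag format assumed from the demo app
    use_torch_compile=False,  # True compiles the decode step: slow first step, faster steady state
    verbose=True,
)
model.save_audio("output.wav", audio)  # 44.1 kHz wav written via soundfile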
dia/state.py ADDED
@@ -0,0 +1,234 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from .config import DiaConfig
6
+
7
+
8
+ def create_attn_mask(
9
+ q_padding_mask_1d: torch.Tensor,
10
+ k_padding_mask_1d: torch.Tensor,
11
+ device: torch.device,
12
+ is_causal: bool = False,
13
+ ) -> torch.Tensor:
14
+ """
15
+ Creates the attention mask (self or cross) mimicking JAX segment ID logic.
16
+ """
17
+ B1, Tq = q_padding_mask_1d.shape
18
+ B2, Tk = k_padding_mask_1d.shape
19
+ assert B1 == B2, "Query and key batch dimensions must match"
20
+
21
+ p_mask_q = q_padding_mask_1d.unsqueeze(2) # Shape [B, Tq, 1]
22
+ p_mask_k = k_padding_mask_1d.unsqueeze(1) # Shape [B, 1, Tk]
23
+
24
+ # Condition A: Non-padding query attends to non-padding key
25
+ non_pad_attends_non_pad = p_mask_q & p_mask_k # Shape [B, Tq, Tk]
26
+
27
+ # Condition B: Padding query attends to padding key
28
+ pad_attends_pad = (~p_mask_q) & (~p_mask_k) # Shape [B, Tq, Tk]
29
+
30
+ # Combine: True if padding status is compatible (both non-pad OR both pad)
31
+ mask = non_pad_attends_non_pad | pad_attends_pad # Shape [B, Tq, Tk]
32
+
33
+ if is_causal:
34
+ assert Tq == Tk, (
35
+ "Causal mask requires query and key sequence lengths to be equal"
36
+ )
37
+ causal_mask_2d = torch.tril(
38
+ torch.ones((Tq, Tk), dtype=torch.bool, device=device)
39
+ ) # Shape [Tq, Tk]
40
+ causal_mask = mask & causal_mask_2d # Shape [B, Tq, Tk]
41
+ return causal_mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
42
+ else:
43
+ return mask.unsqueeze(1) # Shape [B, 1, Tq, Tk]
44
+
45
+
46
+ @dataclass
47
+ class EncoderInferenceState:
48
+ """Parameters specifically for encoder inference."""
49
+
50
+ max_seq_len: int
51
+ device: torch.device
52
+ positions: torch.Tensor
53
+ padding_mask: torch.Tensor
54
+ attn_mask: torch.Tensor
55
+
56
+ @classmethod
57
+ def new(cls, config: DiaConfig, cond_src: torch.Tensor) -> "EncoderInferenceState":
58
+ """Creates EtorchrInferenceParams from DiaConfig and a device."""
59
+ device = cond_src.device
60
+
61
+ positions = (
62
+ torch.arange(config.data.text_length, device=device)
63
+ .to(torch.long)
64
+ .unsqueeze(0)
65
+ .expand(2, -1)
66
+ )
67
+ padding_mask = (cond_src != config.data.text_pad_value).to(device).expand(2, -1)
68
+ attn_mask = create_attn_mask(
69
+ padding_mask, padding_mask, device, is_causal=False
70
+ )
71
+
72
+ return cls(
73
+ max_seq_len=config.data.text_length,
74
+ device=device,
75
+ positions=positions,
76
+ padding_mask=padding_mask,
77
+ attn_mask=attn_mask,
78
+ )
79
+
80
+
81
+ class KVCache:
82
+ def __init__(
83
+ self,
84
+ num_heads: int,
85
+ max_len: int,
86
+ head_dim: int,
87
+ dtype: torch.dtype,
88
+ device: torch.device,
89
+ k: torch.Tensor | None = None,
90
+ v: torch.Tensor | None = None,
91
+ ):
92
+ self.k = (
93
+ torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
94
+ if k is None
95
+ else k
96
+ )
97
+ self.v = (
98
+ torch.zeros((2, num_heads, max_len, head_dim), dtype=dtype, device=device)
99
+ if v is None
100
+ else v
101
+ )
102
+ self.current_idx = torch.tensor(0)
103
+
104
+ @classmethod
105
+ def from_kv(cls, k: torch.Tensor, v: torch.Tensor) -> "KVCache":
106
+ return cls(
107
+ num_heads=k.shape[1],
108
+ max_len=k.shape[2],
109
+ head_dim=k.shape[3],
110
+ dtype=k.dtype,
111
+ device=k.device,
112
+ k=k,
113
+ v=v,
114
+ )
115
+
116
+ def update(
117
+ self, k: torch.Tensor, v: torch.Tensor
118
+ ) -> tuple[torch.Tensor, torch.Tensor]:
119
+ self.k[:, :, self.current_idx : self.current_idx + 1, :] = k
120
+ self.v[:, :, self.current_idx : self.current_idx + 1, :] = v
121
+ self.current_idx += 1
122
+ return self.k[:, :, : self.current_idx, :], self.v[:, :, : self.current_idx, :]
123
+
124
+ def prefill(
125
+ self, k: torch.Tensor, v: torch.Tensor
126
+ ) -> tuple[torch.Tensor, torch.Tensor]:
127
+ prefill_len = k.shape[2]
128
+ self.k[:, :, :prefill_len, :] = k
129
+ self.v[:, :, :prefill_len, :] = v
130
+ self.current_idx = prefill_len - 1
131
+
132
+
133
+ @dataclass
134
+ class DecoderInferenceState:
135
+ """Parameters specifically for decoder inference."""
136
+
137
+ device: torch.device
138
+ dtype: torch.dtype
139
+ enc_out: torch.Tensor
140
+ enc_positions: torch.Tensor
141
+ dec_positions: torch.Tensor
142
+ dec_cross_attn_mask: torch.Tensor
143
+ self_attn_cache: list[KVCache]
144
+ cross_attn_cache: list[KVCache]
145
+
146
+ @classmethod
147
+ def new(
148
+ cls,
149
+ config: DiaConfig,
150
+ enc_state: EncoderInferenceState,
151
+ enc_out: torch.Tensor,
152
+ dec_cross_attn_cache: list[KVCache],
153
+ compute_dtype: torch.dtype,
154
+ ) -> "DecoderInferenceState":
155
+ """Creates DecoderInferenceParams from DiaConfig and a device."""
156
+ device = enc_out.device
157
+ max_audio_len = config.data.audio_length
158
+
159
+ dec_positions = torch.full(
160
+ (2, 1), fill_value=0, dtype=torch.long, device=device
161
+ )
162
+ tgt_padding_mask = torch.ones((2, 1), dtype=torch.bool, device=device)
163
+ dec_cross_attn_mask = create_attn_mask(
164
+ tgt_padding_mask, enc_state.padding_mask, device, is_causal=False
165
+ )
166
+
167
+ self_attn_cache = [
168
+ KVCache(
169
+ config.model.decoder.kv_heads,
170
+ max_audio_len,
171
+ config.model.decoder.gqa_head_dim,
172
+ compute_dtype,
173
+ device,
174
+ )
175
+ for _ in range(config.model.decoder.n_layer)
176
+ ]
177
+
178
+ return cls(
179
+ device=device,
180
+ dtype=compute_dtype,
181
+ enc_out=enc_out,
182
+ enc_positions=enc_state.positions,
183
+ dec_positions=dec_positions,
184
+ dec_cross_attn_mask=dec_cross_attn_mask,
185
+ self_attn_cache=self_attn_cache,
186
+ cross_attn_cache=dec_cross_attn_cache,
187
+ )
188
+
189
+ def prepare_step(self, step_from: int, step_to: int | None = None) -> None:
190
+ if step_to is None:
191
+ step_to = step_from + 1
192
+ self.dec_positions = (
193
+ torch.arange(step_from, step_to, device=self.device)
194
+ .unsqueeze(0)
195
+ .expand(2, -1)
196
+ )
197
+
198
+
199
+ @dataclass
200
+ class DecoderOutput:
201
+ generated_tokens: torch.Tensor
202
+ prefill_step: int
203
+
204
+ @classmethod
205
+ def new(cls, config: DiaConfig, device: torch.device) -> "DecoderOutput":
206
+ max_audio_len = config.data.audio_length
207
+ return cls(
208
+ generated_tokens=torch.full(
209
+ (max_audio_len, config.data.channels),
210
+ fill_value=-1,
211
+ dtype=torch.int,
212
+ device=device,
213
+ ),
214
+ prefill_step=0,
215
+ )
216
+
217
+ def get_tokens_at(self, step_from: int, step_to: int | None = None) -> torch.Tensor:
218
+ if step_to is None:
219
+ step_to = step_from + 1
220
+ return self.generated_tokens[step_from:step_to, :]
221
+
222
+ def update_one(self, dec_out: torch.Tensor, step: int, apply_mask: bool = False):
223
+ if apply_mask:
224
+ mask = self.generated_tokens[step : step + 1, :] == -1
225
+ self.generated_tokens[step : step + 1, :] = torch.where(
226
+ mask, dec_out, self.generated_tokens[step : step + 1, :]
227
+ )
228
+ else:
229
+ self.generated_tokens[step : step + 1, :] = dec_out
230
+
231
+ def prefill(self, dec_out: torch.Tensor, prefill_step: int):
232
+ length = dec_out.shape[0]
233
+ self.generated_tokens[0:length, :] = dec_out
234
+ self.prefill_step = prefill_step
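(Not part of the diff) An illustrative sketch of how the self-attention KVCache above is filled one decode step at a time; shapes follow the (2, num_heads, max_len, head_dim) layout used in this file, with made-up sizes:

import torch
from dia.state import KVCache

cache = KVCache(num_heads=4, max_len=16, head_dim=8,
                dtype=torch.float32, device=torch.device("cpu"))
k_step = torch.randn(2, 4, 1, 8)             # keys for one new decode position
v_step = torch.randn(2, 4, 1, 8)             # values for the same position
k_all, v_all = cache.update(k_step, v_step)  # slices covering every position written so far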