import dac
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download

from .audio import audio_to_codebook, codebook_to_audio
from .config import DiaConfig
from .layers import DiaModel, KVCache


def _sample_next_token(
    logits_BCxV: torch.Tensor,
    temperature: float,
    top_p: float,
    use_cfg_filter: bool,
    cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    if temperature == 0.0:
        return torch.argmax(logits_BCxV, dim=-1)

    logits_BCxV = logits_BCxV / temperature
    if use_cfg_filter and cfg_filter_top_k is not None:
        _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
        mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
        mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
        logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)

    if top_p < 1.0:
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
        sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(probs_BCxV, dim=-1, descending=True)
        cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

        # Calculate indices to remove based on top_p
        sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
        # Shift the mask to the right to keep the first token above the threshold
        sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[..., :-1].clone()
        sorted_indices_to_remove_BCxV[..., 0] = 0  # Always keep the most probable token

        indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
        indices_to_remove_BCxV.scatter_(dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV)
        logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

    final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

    sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
    sampled_indices_C = sampled_indices_BC.squeeze(-1)
    return sampled_indices_C
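

# Illustrative sketch (added, not part of the original API) of the nucleus (top-p)
# filter above: with per-token probabilities [0.5, 0.3, 0.15, 0.05] and top_p=0.7,
# the sorted cumulative sums are [0.5, 0.8, 0.95, 1.0]. The removal mask is shifted
# right so the token that first pushes the sum past top_p (the 0.3 token) is still
# kept; the 0.15 and 0.05 tokens are masked to -inf and the survivors are
# renormalized before torch.multinomial draws one sample per row. The optional
# cfg_filter_top_k branch runs first and simply restricts each row to its
# cfg_filter_top_k highest-scoring tokens.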


class Dia:
    def __init__(self, config: DiaConfig, device: torch.device = torch.device("cuda")):
        """Initializes the Dia model.

        Args:
            config: The configuration object for the model.
            device: The device to load the model onto.

        Raises:
            RuntimeError: If there is an error loading the DAC model.
        """
        super().__init__()
        self.config = config
        self.device = device
        self.model = DiaModel(config)
        self.dac_model = None

    @classmethod
    def from_local(cls, config_path: str, checkpoint_path: str, device: torch.device = torch.device("cuda")) -> "Dia":
        """Loads the Dia model from local configuration and checkpoint files.

        Args:
            config_path: Path to the configuration JSON file.
            checkpoint_path: Path to the model checkpoint (.pth) file.
            device: The device to load the model onto.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If the config or checkpoint file is not found.
            RuntimeError: If there is an error loading the checkpoint.
        """
        config = DiaConfig.load(config_path)
        if config is None:
            raise FileNotFoundError(f"Config file not found at {config_path}")

        dia = cls(config, device)

        try:
            dia.model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
        except Exception as e:
            raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e

        dia.model.to(device)
        dia.model.eval()
        dia._load_dac_model()
        return dia

    @classmethod
    def from_pretrained(
        cls, model_name: str = "nari-labs/Dia-1.6B", device: torch.device = torch.device("cuda")
    ) -> "Dia":
        """Loads the Dia model from a Hugging Face Hub repository.

        Downloads the configuration and checkpoint files from the specified
        repository ID and then loads the model.

        Args:
            model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
            device: The device to load the model onto.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If config or checkpoint download/loading fails.
            RuntimeError: If there is an error loading the checkpoint.
        """
        config_path = hf_hub_download(repo_id=model_name, filename="config.json")
        checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
        return cls.from_local(config_path, checkpoint_path, device)

    def _load_dac_model(self):
        try:
            dac_model_path = dac.utils.download()
            dac_model = dac.DAC.load(dac_model_path).to(self.device)
        except Exception as e:
            raise RuntimeError("Failed to load DAC model") from e
        self.dac_model = dac_model

    def _create_attn_mask(
        self,
        q_padding_mask_1d: torch.Tensor,
        k_padding_mask_1d: torch.Tensor,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Creates the attention mask (self or cross) mimicking JAX segment ID logic.
        """
        B1, Tq = q_padding_mask_1d.shape
        B2, Tk = k_padding_mask_1d.shape
        assert B1 == B2, "Query and key batch dimensions must match"

        p_mask_q = q_padding_mask_1d.unsqueeze(2)  # Shape [B, Tq, 1]
        p_mask_k = k_padding_mask_1d.unsqueeze(1)  # Shape [B, 1, Tk]

        # Condition A: Non-padding query attends to non-padding key
        non_pad_attends_non_pad = p_mask_q & p_mask_k  # Shape [B, Tq, Tk]

        # Condition B: Padding query attends to padding key
        pad_attends_pad = (~p_mask_q) & (~p_mask_k)  # Shape [B, Tq, Tk]

        # Combine: True if padding status is compatible (both non-pad OR both pad)
        # This implementation follows the JAX TPU splash attention kernel
        mask = non_pad_attends_non_pad | pad_attends_pad  # Shape [B, Tq, Tk]

        if is_causal:
            # Ensure causality for self-attention (Tq == Tk)
            assert Tq == Tk, "Causal mask requires query and key sequence lengths to be equal"
            # Standard lower-triangular causal mask (True means allow)
            causal_mask_2d = torch.tril(torch.ones((Tq, Tk), dtype=torch.bool, device=self.device))  # Shape [Tq, Tk]
            causal_mask = mask & causal_mask_2d  # Shape [B, Tq, Tk]
            return causal_mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk] for broadcasting across heads
        else:
            # For cross-attention or non-causal self-attention
            return mask.unsqueeze(1)  # Shape [B, 1, Tq, Tk] for broadcasting across heads

    def _prepare_text_input(self, text: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encodes text prompt, pads, and creates attention mask and positions."""
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length

        byte_text = text.encode("utf-8")
        replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
        text_tokens = list(replaced_bytes)

        current_len = len(text_tokens)
        padding_needed = max_len - current_len
        if padding_needed <= 0:
            text_tokens = text_tokens[:max_len]
            padded_text_np = np.array(text_tokens, dtype=np.uint8)
        else:
            padded_text_np = np.pad(
                text_tokens,
                (0, padding_needed),
                mode="constant",
                constant_values=text_pad_value,
            ).astype(np.uint8)

        src_tokens = torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)  # [1, S]
        src_positions = torch.arange(max_len, device=self.device).to(torch.long).unsqueeze(0)  # [1, S]

        src_padding_mask = (src_tokens != text_pad_value).to(self.device)  # [1, S]

        enc_self_attn_mask = self._create_attn_mask(src_padding_mask, src_padding_mask, is_causal=False)  # [1, 1, S, S]

        return src_tokens, src_positions, src_padding_mask, enc_self_attn_mask
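
    # Worked example of the byte-level text encoding above (illustrative): the
    # prompt "[S1] Hi" is encoded as UTF-8 bytes, the speaker tags are replaced
    # by the special bytes 0x01/0x02, and the resulting token list is
    # [1, 32, 72, 105] ("\x01", " ", "H", "i"), padded with text_pad_value up to
    # config.data.text_length.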

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        max_tokens: int | None = None,
        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
        use_cfg_filter: bool = True,
        use_torch_compile: bool = True,
        cfg_filter_top_k: int = 100,
        audio_prompt_path: str | None = None,
    ) -> np.ndarray:
        """
        Generates audio from a text prompt (and optional audio prompt) using the Nari model.

        Returns:
            A NumPy array containing the generated audio waveform, decoded from
            the predicted codebook tokens.
        """
        num_channels = self.config.data.channels
        audio_bos_value = self.config.data.audio_bos_value
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
        delay_tensor = torch.tensor(delay_pattern, dtype=torch.long, device=self.device)
        max_delay_pattern = max(delay_pattern)
        self.model.eval()

        # 1. Prepare Text Input (batch of 2 for classifier-free guidance)
        (
            cond_src_BxS,
            cond_src_positions_BxS,
            cond_src_padding_mask_BxS,
            cond_enc_self_attn_mask_Bx1xSxS,
        ) = self._prepare_text_input(text)

        # Row 0 is the unconditional (all-zeros text) input, row 1 the conditional text prompt
        unc_src_BxS = torch.zeros_like(cond_src_BxS)
        src_BxS = torch.cat([unc_src_BxS, cond_src_BxS], dim=0)
        src_positions_BxS = cond_src_positions_BxS.expand(2, -1)
        src_padding_mask_BxS = cond_src_padding_mask_BxS.expand(2, -1)
        enc_self_attn_mask_Bx1xSxS = cond_enc_self_attn_mask_Bx1xSxS.expand(2, -1, -1, -1)

        # 2. Encoder Pass
        # with torch.autocast(device_type="cuda", dtype=forward_dtype):
        encoder_out = self.model.encoder(
            x_ids=src_BxS,
            src_positions=src_positions_BxS,
            deterministic=True,
            attn_mask=enc_self_attn_mask_Bx1xSxS,
        )  # Shape: (B, S, E)

        # 3. Prepare Decoder Inputs
        # 3-1. Allocate KV Cache (Static)
        decoder_cross_attention_cache: list[KVCache] = self.model.decoder.precompute_cross_attention_kv(
            max_tokens, encoder_out, src_positions_BxS
        )

        decoder_self_attention_cache: list[KVCache] = []
        for _ in range(self.model.decoder.num_layers):
            decoder_self_attention_cache.append(
                KVCache(
                    self.config.model.decoder.gqa_query_heads,
                    max_tokens,
                    self.config.model.decoder.gqa_head_dim,
                    self.device,
                )
            )

        # 3-2. Initialize Decoder Inputs
        generated_BxTxC = torch.full(
            (2, 1, num_channels),
            fill_value=audio_bos_value,
            dtype=torch.long,
            device=self.device,
        )

        current_step = 0
        prompt_len_inc_bos = 1  # Start with BOS length
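
        # Note: when an audio prompt is supplied, it is encoded into DAC codebook
        # tokens and run through the decoder once as a prefill pass, so the
        # self-attention KV caches already contain the prompt before the
        # step-by-step generation loop starts.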

        # 3-3. Load Audio Prompt (if provided)
        if audio_prompt_path is not None:
            audio_prompt, sr = torchaudio.load(audio_prompt_path, channels_first=True)  # C, T
            if sr != 44100:  # Resample to 44.1kHz
                audio_prompt = torchaudio.functional.resample(audio_prompt, sr, 44100)
            audio_prompt = audio_prompt.to(self.device).unsqueeze(0)  # 1, C, T
            audio_prompt = audio_to_codebook(self.dac_model, audio_prompt, data_config=self.config.data)
            generated_BxTxC = torch.cat([generated_BxTxC, audio_prompt.expand(2, -1, -1)], dim=1)

            prefill_len = generated_BxTxC.shape[1]
            prompt_len_inc_bos = prefill_len
            prefill_tgt_pos = torch.arange(prefill_len, device=self.device).unsqueeze(0).expand(2, -1)
            prefill_tgt_padding_mask = (generated_BxTxC != audio_pad_value).any(dim=2)

            prefill_self_attn_mask = self._create_attn_mask(
                prefill_tgt_padding_mask,
                prefill_tgt_padding_mask,
                is_causal=True,
            )
            prefill_cross_attn_mask = self._create_attn_mask(
                prefill_tgt_padding_mask,
                src_padding_mask_BxS,
                is_causal=False,
            )

            _ = self.model.decoder.forward(
                tgt_ids_BxTxC=generated_BxTxC,
                encoder_out=encoder_out,
                tgt_positions=prefill_tgt_pos,
                src_positions=src_positions_BxS,
                deterministic=True,
                self_attn_mask=prefill_self_attn_mask,
                cross_attn_mask=prefill_cross_attn_mask,
                self_attention_cache=decoder_self_attention_cache,
                cross_attention_cache=decoder_cross_attention_cache,
            )

            current_step = prefill_len - 1

        # 4. Autoregressive Generation Loop
        eos_detected_channel_0 = False
        eos_countdown = -1
        extra_steps_after_eos = 30

        # Make generated_BxTxC a fixed size tensor
        # Length is either 1 + max tokens or 1 + prompt len + max tokens
        generated_BxTxC = torch.cat(
            [
                generated_BxTxC,
                torch.full(
                    (2, max_tokens, num_channels),
                    fill_value=-1,
                    dtype=torch.long,
                    device=self.device,
                ),
            ],
            dim=1,
        )

        decode_step = self.model.decoder.decode_step
        if use_torch_compile:
            decode_step = torch.compile(
                self.model.decoder.decode_step,
                mode="default",
            )

        tgt_padding_mask = (
            (generated_BxTxC[:, -1, :].unsqueeze(1) != audio_pad_value).any(dim=2).to(self.device)
        )  # [B, 1]
        # Generated tokens are never PAD, so we use fixed mask
        decoder_cross_attn_mask = self._create_attn_mask(
            tgt_padding_mask,  # Query mask [B, 1]
            src_padding_mask_BxS,  # Key mask [B, S]
            is_causal=False,
        )  # [B, 1, 1, S]
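
        # Note: each iteration feeds the most recent frame for both CFG rows,
        # mixes the conditional and unconditional logits with cfg_scale, masks
        # out token ids from 1025 upward, and samples one token per codebook
        # channel. Without an audio prompt, channels whose delay has not yet
        # elapsed are forced to BOS so the delay pattern is honored.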

        for step in range(current_step, current_step + max_tokens):
            tgt_ids_Bx1xC = generated_BxTxC[:, step, :].unsqueeze(1)
            tgt_pos_Bx1 = torch.full(
                (2, 1),
                fill_value=step,
                dtype=torch.long,
                device=self.device,
            )

            logits_Bx1xCxV, new_cache = decode_step(
                tgt_ids_Bx1xC=tgt_ids_Bx1xC,
                tgt_pos_Bx1=tgt_pos_Bx1,
                encoder_out=encoder_out,
                self_attn_mask=None,
                cross_attn_mask=decoder_cross_attn_mask,
                self_attention_cache=decoder_self_attention_cache,
                cross_attention_cache=decoder_cross_attention_cache,
            )

            for i, layer_cache in enumerate(decoder_self_attention_cache):
                layer_cache.update_cache(new_cache[i][0], new_cache[i][1])

            V = self.config.model.tgt_vocab_size
            logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]  # B, C, V
            uncond_logits_CxV = logits_last_BxCxV[0, :, :]
            cond_logits_CxV = logits_last_BxCxV[1, :, :]

            cfg_logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)

            logits_CxV = cfg_logits_CxV.reshape((-1, V))  # C, V
            logits_CxV[:, 1025:] = -torch.inf

            # Sample next token
            pred_C = _sample_next_token(
                logits_CxV.float(),
                temperature=temperature,
                top_p=top_p,
                use_cfg_filter=use_cfg_filter,
                cfg_filter_top_k=cfg_filter_top_k,
            )

            generation_step_index = step - current_step
            if audio_prompt_path is None:
                pred_C = torch.where(
                    generation_step_index >= delay_tensor,
                    pred_C,
                    audio_bos_value,
                )

            generated_BxTxC[:, step + 1, :] = pred_C.unsqueeze(0).expand(2, -1)

            if not eos_detected_channel_0 and pred_C[0] == audio_eos_value:
                eos_detected_channel_0 = True
                eos_countdown = extra_steps_after_eos

            if eos_countdown > 0:
                step_after_eos = max_delay_pattern - eos_countdown
                for i, d in enumerate(delay_pattern):
                    if step_after_eos == d:
                        generated_BxTxC[:, step + 1, i] = audio_eos_value
                    elif step_after_eos > d:
                        generated_BxTxC[:, step + 1, i] = audio_pad_value
                eos_countdown -= 1
                if eos_countdown == 0:
                    break

            generation_step_index = step - current_step + 1

        output_codes = generated_BxTxC[:, prompt_len_inc_bos : step + 1, :]
        generated_codes = output_codes[0]

        audio = codebook_to_audio(
            generated_codes.transpose(1, 0), self.dac_model, delay_pattern, B=1, T=max_tokens, C=num_channels
        )
        return audio.squeeze().cpu().numpy()
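

# Example usage (illustrative sketch, not part of the original module): loads the
# pretrained weights and writes the generated waveform to disk. Assumes a CUDA
# device is available and the optional `soundfile` package is installed; run as
# `python -m dia.model` so the relative imports above resolve.
if __name__ == "__main__":
    import soundfile as sf  # assumed extra dependency for saving the output

    model = Dia.from_pretrained("nari-labs/Dia-1.6B")
    waveform = model.generate("[S1] Dia is an open weights text to dialogue model. [S2] Try it out!")
    # Dia works with 44.1 kHz audio, matching the DAC codec and the resampling above.
    sf.write("output.wav", waveform, 44100)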