import time
from enum import Enum

import dac
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download

from .audio import (
    apply_audio_delay,
    build_delay_indices,
    build_revert_indices,
    decode,
    revert_audio_delay,
)
from .config import DiaConfig
from .layers import DiaModel
from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState


DEFAULT_SAMPLE_RATE = 44100


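# Device preference for inference: CUDA first, then Apple-silicon MPS, then CPU.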
def _get_default_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


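# Sampling helper used by _decoder_step. temperature == 0.0 short-circuits to
# greedy argmax; otherwise logits are temperature-scaled, optionally masked to
# the top-k candidates (cfg_filter_top_k), then nucleus-filtered so that only
# the smallest set of tokens whose cumulative probability exceeds top_p
# survives before the final multinomial draw.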
def _sample_next_token(
    logits_BCxV: torch.Tensor,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    if temperature == 0.0:
        return torch.argmax(logits_BCxV, dim=-1)

    logits_BCxV = logits_BCxV / temperature
    if cfg_filter_top_k is not None:
        _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
        mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
        mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
        logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)

    if top_p < 1.0:
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
        sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
            probs_BCxV, dim=-1, descending=True
        )
        cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

        sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
        sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
            ..., :-1
        ].clone()
        sorted_indices_to_remove_BCxV[..., 0] = 0

        indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
        indices_to_remove_BCxV.scatter_(
            dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
        )
        logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

    final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

    sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
    sampled_indices_C = sampled_indices_BC.squeeze(-1)
    return sampled_indices_C


class ComputeDtype(str, Enum):
    FLOAT32 = "float32"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"

    def to_dtype(self) -> torch.dtype:
        if self == ComputeDtype.FLOAT32:
            return torch.float32
        elif self == ComputeDtype.FLOAT16:
            return torch.float16
        elif self == ComputeDtype.BFLOAT16:
            return torch.bfloat16
        else:
            raise ValueError(f"Unsupported compute dtype: {self}")


class Dia:
    def __init__(
        self,
        config: DiaConfig,
        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
        device: torch.device | None = None,
    ):
        """Initializes the Dia model.

        Args:
            config: The configuration object for the model.
            compute_dtype: The computation dtype, given as a ComputeDtype or one
                of the strings "float32", "float16", or "bfloat16".
            device: The device to load the model onto. If None, will automatically select the best available device.

        Raises:
            RuntimeError: If there is an error loading the DAC model.
        """
        super().__init__()
        self.config = config
        self.device = device if device is not None else _get_default_device()
        if isinstance(compute_dtype, str):
            compute_dtype = ComputeDtype(compute_dtype)
        self.compute_dtype = compute_dtype.to_dtype()
        self.model = DiaModel(config, self.compute_dtype)
        self.dac_model = None

    @classmethod
    def from_local(
        cls,
        config_path: str,
        checkpoint_path: str,
        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
        device: torch.device | None = None,
    ) -> "Dia":
        """Loads the Dia model from local configuration and checkpoint files.

        Args:
            config_path: Path to the configuration JSON file.
            checkpoint_path: Path to the model checkpoint (.pth) file.
            compute_dtype: The computation dtype to use.
            device: The device to load the model onto. If None, will automatically select the best available device.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If the config or checkpoint file is not found.
            RuntimeError: If there is an error loading the checkpoint.
        """
        config = DiaConfig.load(config_path)
        if config is None:
            raise FileNotFoundError(f"Config file not found at {config_path}")

        dia = cls(config, compute_dtype, device)

        try:
            state_dict = torch.load(checkpoint_path, map_location=dia.device)
            dia.model.load_state_dict(state_dict)
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
        except Exception as e:
            raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e

        dia.model.to(dia.device)
        dia.model.eval()
        dia._load_dac_model()
        return dia

    @classmethod
    def from_pretrained(
        cls,
        model_name: str = "nari-labs/Dia-1.6B",
        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
        device: torch.device | None = None,
    ) -> "Dia":
        """Loads the Dia model from a Hugging Face Hub repository.

        Downloads the configuration and checkpoint files from the specified
        repository ID and then loads the model.

        Args:
            model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
            compute_dtype: The computation dtype to use.
            device: The device to load the model onto. If None, will automatically select the best available device.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If config or checkpoint download/loading fails.
            RuntimeError: If there is an error loading the checkpoint.
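
        Example (a minimal usage sketch):
            dia = Dia.from_pretrained("nari-labs/Dia-1.6B")
            audio = dia.generate("[S1] Hello! [S2] Hi there.")
            dia.save_audio("output.wav", audio)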
""" |
|
config_path = hf_hub_download(repo_id=model_name, filename="config.json") |
|
checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth") |
|
return cls.from_local(config_path, checkpoint_path, compute_dtype, device) |
|
|
|
    def _load_dac_model(self):
        try:
            dac_model_path = dac.utils.download()
            dac_model = dac.DAC.load(dac_model_path).to(self.device)
        except Exception as e:
            raise RuntimeError("Failed to load DAC model") from e
        self.dac_model = dac_model

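    # Text is tokenized at the byte level: the speaker tags [S1] and [S2] are
    # mapped to the reserved byte values 0x01 and 0x02, and the sequence is
    # padded (or truncated) to the fixed config.data.text_length.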
    def _prepare_text_input(self, text: str) -> torch.Tensor:
        """Encodes the text prompt as bytes and pads it to the fixed text length."""
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length

        byte_text = text.encode("utf-8")
        replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
        text_tokens = list(replaced_bytes)

        current_len = len(text_tokens)
        padding_needed = max_len - current_len
        if padding_needed <= 0:
            text_tokens = text_tokens[:max_len]
            padded_text_np = np.array(text_tokens, dtype=np.uint8)
        else:
            padded_text_np = np.pad(
                text_tokens,
                (0, padding_needed),
                mode="constant",
                constant_values=text_pad_value,
            ).astype(np.uint8)

        src_tokens = (
            torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)
        )
        return src_tokens

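    # The decoder prefill starts with a single BOS frame across all codebook
    # channels, followed by the optional audio prompt; max(delay_pattern)
    # placeholder frames are appended so apply_audio_delay can shift each
    # channel by its per-channel delay.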
    def _prepare_audio_prompt(
        self, audio_prompt: torch.Tensor | None
    ) -> tuple[torch.Tensor, int]:
        num_channels = self.config.data.channels
        audio_bos_value = self.config.data.audio_bos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_delay_pattern = max(delay_pattern)

        prefill = torch.full(
            (1, num_channels),
            fill_value=audio_bos_value,
            dtype=torch.int,
            device=self.device,
        )

        prefill_step = 1

        if audio_prompt is not None:
            prefill_step += audio_prompt.shape[0]
            prefill = torch.cat([prefill, audio_prompt], dim=0)

        delay_pad_tensor = torch.full(
            (max_delay_pattern, num_channels),
            fill_value=-1,
            dtype=torch.int,
            device=self.device,
        )
        prefill = torch.cat([prefill, delay_pad_tensor], dim=0)

        delay_precomp = build_delay_indices(
            B=1,
            T=prefill.shape[0],
            C=num_channels,
            delay_pattern=delay_pattern,
        )

        prefill = apply_audio_delay(
            audio_BxTxC=prefill.unsqueeze(0),
            pad_value=audio_pad_value,
            bos_value=audio_bos_value,
            precomp=delay_precomp,
        ).squeeze(0)

        return prefill, prefill_step

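    # Classifier-free guidance is implemented by batching two encoder inputs:
    # index 0 is the unconditional branch (all-zero text) and index 1 the
    # conditional branch, so every decoder step produces both sets of logits
    # in a single forward pass.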
    def _prepare_generation(
        self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool
    ):
        enc_input_cond = self._prepare_text_input(text)
        enc_input_uncond = torch.zeros_like(enc_input_cond)
        enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)

        if isinstance(audio_prompt, str):
            audio_prompt = self.load_audio(audio_prompt)
        prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)

        if verbose:
            print("generate: data loaded")

        enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
        encoder_out = self.model.encoder(enc_input, enc_state)

        dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
            encoder_out, enc_state.positions
        )
        dec_state = DecoderInferenceState.new(
            self.config,
            enc_state,
            encoder_out,
            dec_cross_attn_cache,
            self.compute_dtype,
        )
        dec_output = DecoderOutput.new(self.config, self.device)
        dec_output.prefill(prefill, prefill_step)

        dec_step = prefill_step - 1
        if dec_step > 0:
            dec_state.prepare_step(0, dec_step)
            tokens_BxTxC = (
                dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
            )
            self.model.decoder.forward(tokens_BxTxC, dec_state)

        return dec_state, dec_output

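    # One autoregressive step: combine the two CFG branches as
    # cond + cfg_scale * (cond - uncond), forbid every token above EOS, and
    # allow EOS itself only on the first codebook channel.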
    def _decoder_step(
        self,
        tokens_Bx1xC: torch.Tensor,
        dec_state: DecoderInferenceState,
        cfg_scale: float,
        temperature: float,
        top_p: float,
        cfg_filter_top_k: int,
    ) -> torch.Tensor:
        audio_eos_value = self.config.data.audio_eos_value
        logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)

        logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
        uncond_logits_CxV = logits_last_BxCxV[0, :, :]
        cond_logits_CxV = logits_last_BxCxV[1, :, :]

        logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
        logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
        logits_CxV[1:, audio_eos_value:] = -torch.inf

        pred_C = _sample_next_token(
            logits_CxV.float(),
            temperature=temperature,
            top_p=top_p,
            cfg_filter_top_k=cfg_filter_top_k,
        )
        return pred_C

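    # Post-processing: undo the per-channel delay, drop the trailing
    # max_delay_pattern frames, clamp codes outside the valid codebook range
    # [0, 1023] to 0, and decode the codes to a waveform with the DAC model.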
    def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
        num_channels = self.config.data.channels
        seq_length = generated_codes.shape[0]
        delay_pattern = self.config.data.delay_pattern
        audio_pad_value = self.config.data.audio_pad_value
        max_delay_pattern = max(delay_pattern)

        revert_precomp = build_revert_indices(
            B=1,
            T=seq_length,
            C=num_channels,
            delay_pattern=delay_pattern,
        )

        codebook = revert_audio_delay(
            audio_BxTxC=generated_codes.unsqueeze(0),
            pad_value=audio_pad_value,
            precomp=revert_precomp,
            T=seq_length,
        )[:, :-max_delay_pattern, :]

        min_valid_index = 0
        max_valid_index = 1023
        invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
        codebook[invalid_mask] = 0

        audio = decode(self.dac_model, codebook.transpose(1, 2))

        return audio.squeeze().cpu().numpy()

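    # Encodes a reference audio file into DAC codebook indices (shape [T, C])
    # for use as an audio prompt, resampling to 44.1 kHz if necessary.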
    def load_audio(self, audio_path: str) -> torch.Tensor:
        audio, sr = torchaudio.load(audio_path, channels_first=True)
        if sr != DEFAULT_SAMPLE_RATE:
            audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
        audio = audio.to(self.device).unsqueeze(0)
        audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
        _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)
        return encoded_frame.squeeze(0).transpose(0, 1)

    def save_audio(self, path: str, audio: np.ndarray):
        import soundfile as sf

        sf.write(path, audio, DEFAULT_SAMPLE_RATE)

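    # Main entry point. Each loop iteration below emits one frame of codebook
    # tokens; once the first channel produces EOS (or the token budget is
    # nearly exhausted), an eos_countdown of max(delay_pattern) steps flushes
    # the remaining delayed channels with EOS/PAD markers before stopping.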
    @torch.inference_mode()
    def generate(
        self,
        text: str,
        max_tokens: int | None = None,
        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
        use_torch_compile: bool = False,
        cfg_filter_top_k: int = 35,
        audio_prompt: str | torch.Tensor | None = None,
        audio_prompt_path: str | None = None,
        use_cfg_filter: bool | None = None,
        verbose: bool = False,
    ) -> np.ndarray | None:
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
        max_delay_pattern = max(delay_pattern)
        self.model.eval()

        if audio_prompt_path:
            print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
            audio_prompt = audio_prompt_path
        if use_cfg_filter is not None:
            print("Warning: use_cfg_filter is deprecated.")

        if verbose:
            total_start_time = time.time()

        dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
        dec_step = dec_output.prefill_step - 1

        bos_countdown = max_delay_pattern
        eos_detected = False
        eos_countdown = -1

        if use_torch_compile:
            step_fn = torch.compile(self._decoder_step, mode="default")
        else:
            step_fn = self._decoder_step

        if verbose:
            print("generate: starting generation loop")
            if use_torch_compile:
                print(
                    "generate: use_torch_compile=True, so the first step will be slow while the graph compiles"
                )
            start_time = time.time()

        while dec_step < max_tokens:
            dec_state.prepare_step(dec_step)
            tokens_Bx1xC = (
                dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
            )
            pred_C = step_fn(
                tokens_Bx1xC,
                dec_state,
                cfg_scale,
                temperature,
                top_p,
                cfg_filter_top_k,
            )

            if (
                not eos_detected and pred_C[0] == audio_eos_value
            ) or dec_step == max_tokens - max_delay_pattern - 1:
                eos_detected = True
                eos_countdown = max_delay_pattern

            if eos_countdown > 0:
                step_after_eos = max_delay_pattern - eos_countdown
                for i, d in enumerate(delay_pattern):
                    if step_after_eos == d:
                        pred_C[i] = audio_eos_value
                    elif step_after_eos > d:
                        pred_C[i] = audio_pad_value
                eos_countdown -= 1

            bos_countdown = max(0, bos_countdown - 1)
            dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)

            if eos_countdown == 0:
                break

            dec_step += 1

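            # Note: ~86 DAC frames correspond to about one second of 44.1 kHz
            # audio, so the per-86-step timing below doubles as a realtime
            # factor estimate.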
            if verbose and dec_step % 86 == 0:
                duration = time.time() - start_time
                print(
                    f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
                )
                start_time = time.time()

        if dec_output.prefill_step >= dec_step + 1:
            print("Warning: Nothing generated")
            return None

        generated_codes = dec_output.generated_tokens[
            dec_output.prefill_step : dec_step + 1, :
        ]

        if verbose:
            total_step = dec_step + 1 - dec_output.prefill_step
            total_duration = time.time() - total_start_time
            print(
                f"generate: total step={total_step}, total duration={total_duration:.3f}s"
            )

        return self._generate_output(generated_codes)