import time
from enum import Enum

import dac
import numpy as np
import torch
import torchaudio
from huggingface_hub import hf_hub_download

from .audio import (
    apply_audio_delay,
    build_delay_indices,
    build_revert_indices,
    decode,
    revert_audio_delay,
)
from .config import DiaConfig
from .layers import DiaModel
from .state import DecoderInferenceState, DecoderOutput, EncoderInferenceState


DEFAULT_SAMPLE_RATE = 44100


def _get_default_device():
    if torch.cuda.is_available():
        return torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def _sample_next_token(
    logits_BCxV: torch.Tensor,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int | None = None,
) -> torch.Tensor:
    # Greedy decoding when temperature is zero.
    if temperature == 0.0:
        return torch.argmax(logits_BCxV, dim=-1)

    logits_BCxV = logits_BCxV / temperature

    # Top-k filtering: mask out everything outside the k highest logits.
    if cfg_filter_top_k is not None:
        _, top_k_indices_BCxV = torch.topk(logits_BCxV, k=cfg_filter_top_k, dim=-1)
        mask = torch.ones_like(logits_BCxV, dtype=torch.bool)
        mask.scatter_(dim=-1, index=top_k_indices_BCxV, value=False)
        logits_BCxV = logits_BCxV.masked_fill(mask, -torch.inf)

    # Nucleus (top-p) filtering: drop the tail of the sorted probability mass.
    if top_p < 1.0:
        probs_BCxV = torch.softmax(logits_BCxV, dim=-1)
        sorted_probs_BCxV, sorted_indices_BCxV = torch.sort(
            probs_BCxV, dim=-1, descending=True
        )
        cumulative_probs_BCxV = torch.cumsum(sorted_probs_BCxV, dim=-1)

        sorted_indices_to_remove_BCxV = cumulative_probs_BCxV > top_p
        # Shift right so the first token crossing the threshold is kept.
        sorted_indices_to_remove_BCxV[..., 1:] = sorted_indices_to_remove_BCxV[
            ..., :-1
        ].clone()
        sorted_indices_to_remove_BCxV[..., 0] = 0

        indices_to_remove_BCxV = torch.zeros_like(sorted_indices_to_remove_BCxV)
        indices_to_remove_BCxV.scatter_(
            dim=-1, index=sorted_indices_BCxV, src=sorted_indices_to_remove_BCxV
        )
        logits_BCxV = logits_BCxV.masked_fill(indices_to_remove_BCxV, -torch.inf)

    final_probs_BCxV = torch.softmax(logits_BCxV, dim=-1)

    sampled_indices_BC = torch.multinomial(final_probs_BCxV, num_samples=1)
    sampled_indices_C = sampled_indices_BC.squeeze(-1)
    return sampled_indices_C
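

# Minimal usage sketch of _sample_next_token (shapes are illustrative only;
# the real channel count and vocabulary size come from the model config):
#   logits = torch.randn(9, 1028)  # [C, V]: one row of logits per codebook channel
#   greedy = _sample_next_token(logits, temperature=0.0, top_p=1.0)
#   tokens = _sample_next_token(logits, temperature=1.3, top_p=0.95,
#                               cfg_filter_top_k=35)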


class ComputeDtype(str, Enum):
    FLOAT32 = "float32"
    FLOAT16 = "float16"
    BFLOAT16 = "bfloat16"

    def to_dtype(self) -> torch.dtype:
        if self == ComputeDtype.FLOAT32:
            return torch.float32
        elif self == ComputeDtype.FLOAT16:
            return torch.float16
        elif self == ComputeDtype.BFLOAT16:
            return torch.bfloat16
        else:
            raise ValueError(f"Unsupported compute dtype: {self}")
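

# Example: ComputeDtype maps its string values onto torch dtypes, so callers
# can pass either form, e.g.
#   ComputeDtype("bfloat16").to_dtype()  # -> torch.bfloat16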


class Dia:
    def __init__(
        self,
        config: DiaConfig,
        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
        device: torch.device | None = None,
    ):
        """Initializes the Dia model.

        Args:
            config: The configuration object for the model.
            compute_dtype: The dtype used for computation, as a string or ComputeDtype.
            device: The device to load the model onto. If None, the best available
                device is selected automatically.
        """
        super().__init__()
        self.config = config
        self.device = device if device is not None else _get_default_device()
        if isinstance(compute_dtype, str):
            compute_dtype = ComputeDtype(compute_dtype)
        self.compute_dtype = compute_dtype.to_dtype()
        self.model = DiaModel(config, self.compute_dtype)
        self.dac_model = None
    @classmethod
    def from_local(
        cls,
        config_path: str,
        checkpoint_path: str,
        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
        device: torch.device | None = None,
    ) -> "Dia":
        """Loads the Dia model from local configuration and checkpoint files.

        Args:
            config_path: Path to the configuration JSON file.
            checkpoint_path: Path to the model checkpoint (.pth) file.
            compute_dtype: The dtype used for computation.
            device: The device to load the model onto. If None, the best available
                device is selected automatically.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If the config or checkpoint file is not found.
            RuntimeError: If there is an error loading the checkpoint.
        """
        config = DiaConfig.load(config_path)
        if config is None:
            raise FileNotFoundError(f"Config file not found at {config_path}")

        dia = cls(config, compute_dtype, device)

        try:
            state_dict = torch.load(checkpoint_path, map_location=dia.device)
            dia.model.load_state_dict(state_dict)
        except FileNotFoundError:
            raise FileNotFoundError(f"Checkpoint file not found at {checkpoint_path}")
        except Exception as e:
            raise RuntimeError(f"Error loading checkpoint from {checkpoint_path}") from e

        dia.model.to(dia.device)
        dia.model.eval()
        dia._load_dac_model()
        return dia

    @classmethod
    def from_pretrained(
        cls,
        model_name: str = "nari-labs/Dia-1.6B",
        compute_dtype: str | ComputeDtype = ComputeDtype.FLOAT32,
        device: torch.device | None = None,
    ) -> "Dia":
        """Loads the Dia model from a Hugging Face Hub repository.

        Downloads the configuration and checkpoint files from the specified
        repository ID and then loads the model.

        Args:
            model_name: The Hugging Face Hub repository ID (e.g., "nari-labs/Dia-1.6B").
            compute_dtype: The dtype used for computation.
            device: The device to load the model onto. If None, the best available
                device is selected automatically.

        Returns:
            An instance of the Dia model loaded with weights and set to eval mode.

        Raises:
            FileNotFoundError: If config or checkpoint download/loading fails.
            RuntimeError: If there is an error loading the checkpoint.
        """
        config_path = hf_hub_download(repo_id=model_name, filename="config.json")
        checkpoint_path = hf_hub_download(repo_id=model_name, filename="dia-v0_1.pth")
        return cls.from_local(config_path, checkpoint_path, compute_dtype, device)
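
    # Usage sketch (the local file paths below are hypothetical;
    # from_pretrained requires access to the Hugging Face Hub):
    #   model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
    #   model = Dia.from_local("config.json", "dia-v0_1.pth")  # offline variant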

    def _load_dac_model(self):
        try:
            dac_model_path = dac.utils.download()
            dac_model = dac.DAC.load(dac_model_path).to(self.device)
        except Exception as e:
            raise RuntimeError("Failed to load DAC model") from e
        self.dac_model = dac_model

    def _prepare_text_input(self, text: str) -> torch.Tensor:
        """Encodes text prompt, pads, and creates attention mask and positions."""
        text_pad_value = self.config.data.text_pad_value
        max_len = self.config.data.text_length

        byte_text = text.encode("utf-8")
        replaced_bytes = byte_text.replace(b"[S1]", b"\x01").replace(b"[S2]", b"\x02")
        text_tokens = list(replaced_bytes)

        current_len = len(text_tokens)
        padding_needed = max_len - current_len
        if padding_needed <= 0:
            text_tokens = text_tokens[:max_len]
            padded_text_np = np.array(text_tokens, dtype=np.uint8)
        else:
            padded_text_np = np.pad(
                text_tokens,
                (0, padding_needed),
                mode="constant",
                constant_values=text_pad_value,
            ).astype(np.uint8)

        src_tokens = (
            torch.from_numpy(padded_text_np).to(torch.long).to(self.device).unsqueeze(0)
        )  # [1, S]
        return src_tokens
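
    # For example, "[S1] Hi" encodes byte-wise to [0x01, 0x20, 0x48, 0x69]:
    # the speaker tags [S1]/[S2] collapse to the single control bytes
    # \x01/\x02 before padding to config.data.text_length.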

    def _prepare_audio_prompt(
        self, audio_prompt: torch.Tensor | None
    ) -> tuple[torch.Tensor, int]:
        num_channels = self.config.data.channels
        audio_bos_value = self.config.data.audio_bos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_delay_pattern = max(delay_pattern)

        prefill = torch.full(
            (1, num_channels),
            fill_value=audio_bos_value,
            dtype=torch.int,
            device=self.device,
        )

        prefill_step = 1

        if audio_prompt is not None:
            prefill_step += audio_prompt.shape[0]
            prefill = torch.cat([prefill, audio_prompt], dim=0)

        delay_pad_tensor = torch.full(
            (max_delay_pattern, num_channels),
            fill_value=-1,
            dtype=torch.int,
            device=self.device,
        )
        prefill = torch.cat([prefill, delay_pad_tensor], dim=0)

        delay_precomp = build_delay_indices(
            B=1,
            T=prefill.shape[0],
            C=num_channels,
            delay_pattern=delay_pattern,
        )

        prefill = apply_audio_delay(
            audio_BxTxC=prefill.unsqueeze(0),
            pad_value=audio_pad_value,
            bos_value=audio_bos_value,
            precomp=delay_precomp,
        ).squeeze(0)

        return prefill, prefill_step
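
    # Sketch of the intended layout: apply_audio_delay offsets codebook
    # channel i by delay_pattern[i] frames (per the audio helpers imported
    # above), so later channels trail channel 0; the BOS row covers the
    # shifted-in leading positions and the trailing -1 rows absorb the
    # overhang.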

    def _prepare_generation(
        self, text: str, audio_prompt: str | torch.Tensor | None, verbose: bool
    ):
        enc_input_cond = self._prepare_text_input(text)
        enc_input_uncond = torch.zeros_like(enc_input_cond)
        enc_input = torch.cat([enc_input_uncond, enc_input_cond], dim=0)

        if isinstance(audio_prompt, str):
            audio_prompt = self.load_audio(audio_prompt)
        prefill, prefill_step = self._prepare_audio_prompt(audio_prompt)

        if verbose:
            print("generate: data loaded")

        enc_state = EncoderInferenceState.new(self.config, enc_input_cond)
        encoder_out = self.model.encoder(enc_input, enc_state)

        dec_cross_attn_cache = self.model.decoder.precompute_cross_attn_cache(
            encoder_out, enc_state.positions
        )
        dec_state = DecoderInferenceState.new(
            self.config,
            enc_state,
            encoder_out,
            dec_cross_attn_cache,
            self.compute_dtype,
        )
        dec_output = DecoderOutput.new(self.config, self.device)
        dec_output.prefill(prefill, prefill_step)

        dec_step = prefill_step - 1
        if dec_step > 0:
            dec_state.prepare_step(0, dec_step)
            tokens_BxTxC = (
                dec_output.get_tokens_at(0, dec_step).unsqueeze(0).expand(2, -1, -1)
            )
            self.model.decoder.forward(tokens_BxTxC, dec_state)

        return dec_state, dec_output

    def _decoder_step(
        self,
        tokens_Bx1xC: torch.Tensor,
        dec_state: DecoderInferenceState,
        cfg_scale: float,
        temperature: float,
        top_p: float,
        cfg_filter_top_k: int,
    ) -> torch.Tensor:
        audio_eos_value = self.config.data.audio_eos_value
        logits_Bx1xCxV = self.model.decoder.decode_step(tokens_Bx1xC, dec_state)

        logits_last_BxCxV = logits_Bx1xCxV[:, -1, :, :]
        uncond_logits_CxV = logits_last_BxCxV[0, :, :]
        cond_logits_CxV = logits_last_BxCxV[1, :, :]

        # Classifier-free guidance as implemented here:
        #   guided = cond + cfg_scale * (cond - uncond)
        # i.e. conditional logits pushed away from the unconditional ones.
        logits_CxV = cond_logits_CxV + cfg_scale * (cond_logits_CxV - uncond_logits_CxV)
        # Disallow tokens past EOS on all channels, and EOS itself on channels > 0.
        logits_CxV[:, audio_eos_value + 1 :] = -torch.inf
        logits_CxV[1:, audio_eos_value:] = -torch.inf

        pred_C = _sample_next_token(
            logits_CxV.float(),
            temperature=temperature,
            top_p=top_p,
            cfg_filter_top_k=cfg_filter_top_k,
        )
        return pred_C

    def _generate_output(self, generated_codes: torch.Tensor) -> np.ndarray:
        num_channels = self.config.data.channels
        seq_length = generated_codes.shape[0]
        delay_pattern = self.config.data.delay_pattern
        audio_pad_value = self.config.data.audio_pad_value
        max_delay_pattern = max(delay_pattern)

        revert_precomp = build_revert_indices(
            B=1,
            T=seq_length,
            C=num_channels,
            delay_pattern=delay_pattern,
        )

        codebook = revert_audio_delay(
            audio_BxTxC=generated_codes.unsqueeze(0),
            pad_value=audio_pad_value,
            precomp=revert_precomp,
            T=seq_length,
        )[:, :-max_delay_pattern, :]

        # Clamp out-of-range entries (special tokens such as BOS/EOS/PAD) to a
        # valid DAC codebook index before decoding.
        min_valid_index = 0
        max_valid_index = 1023
        invalid_mask = (codebook < min_valid_index) | (codebook > max_valid_index)
        codebook[invalid_mask] = 0

        audio = decode(self.dac_model, codebook.transpose(1, 2))

        return audio.squeeze().cpu().numpy()

    def load_audio(self, audio_path: str) -> torch.Tensor:
        audio, sr = torchaudio.load(audio_path, channels_first=True)  # C, T
        if sr != DEFAULT_SAMPLE_RATE:
            audio = torchaudio.functional.resample(audio, sr, DEFAULT_SAMPLE_RATE)
        audio = audio.to(self.device).unsqueeze(0)  # 1, C, T
        audio_data = self.dac_model.preprocess(audio, DEFAULT_SAMPLE_RATE)
        _, encoded_frame, _, _, _ = self.dac_model.encode(audio_data)  # 1, C, T
        return encoded_frame.squeeze(0).transpose(0, 1)

    def save_audio(self, path: str, audio: np.ndarray):
        import soundfile as sf

        sf.write(path, audio, DEFAULT_SAMPLE_RATE)

    @torch.inference_mode()
    def generate(
        self,
        text: str,
        max_tokens: int | None = None,
        cfg_scale: float = 3.0,
        temperature: float = 1.3,
        top_p: float = 0.95,
        use_torch_compile: bool = False,
        cfg_filter_top_k: int = 35,
        audio_prompt: str | torch.Tensor | None = None,
        audio_prompt_path: str | None = None,
        use_cfg_filter: bool | None = None,
        verbose: bool = False,
    ) -> np.ndarray | None:
        audio_eos_value = self.config.data.audio_eos_value
        audio_pad_value = self.config.data.audio_pad_value
        delay_pattern = self.config.data.delay_pattern
        max_tokens = self.config.data.audio_length if max_tokens is None else max_tokens
        max_delay_pattern = max(delay_pattern)
        self.model.eval()

        if audio_prompt_path:
            print("Warning: audio_prompt_path is deprecated. Use audio_prompt instead.")
            audio_prompt = audio_prompt_path
        if use_cfg_filter is not None:
            print("Warning: use_cfg_filter is deprecated.")

        if verbose:
            total_start_time = time.time()

        dec_state, dec_output = self._prepare_generation(text, audio_prompt, verbose)
        dec_step = dec_output.prefill_step - 1

        bos_countdown = max_delay_pattern
        eos_detected = False
        eos_countdown = -1

        if use_torch_compile:
            step_fn = torch.compile(self._decoder_step, mode="default")
        else:
            step_fn = self._decoder_step

        if verbose:
            print("generate: starting generation loop")
            if use_torch_compile:
                print("generate: with use_torch_compile=True, the first step will be slow")
            start_time = time.time()

        while dec_step < max_tokens:
            dec_state.prepare_step(dec_step)
            tokens_Bx1xC = (
                dec_output.get_tokens_at(dec_step).unsqueeze(0).expand(2, -1, -1)
            )
            pred_C = step_fn(
                tokens_Bx1xC,
                dec_state,
                cfg_scale,
                temperature,
                top_p,
                cfg_filter_top_k,
            )

            if (
                not eos_detected and pred_C[0] == audio_eos_value
            ) or dec_step == max_tokens - max_delay_pattern - 1:
                eos_detected = True
                eos_countdown = max_delay_pattern

            if eos_countdown > 0:
                step_after_eos = max_delay_pattern - eos_countdown
                for i, d in enumerate(delay_pattern):
                    if step_after_eos == d:
                        pred_C[i] = audio_eos_value
                    elif step_after_eos > d:
                        pred_C[i] = audio_pad_value
                eos_countdown -= 1

            bos_countdown = max(0, bos_countdown - 1)
            dec_output.update_one(pred_C, dec_step + 1, bos_countdown > 0)

            if eos_countdown == 0:
                break

            dec_step += 1
            if verbose and dec_step % 86 == 0:
                duration = time.time() - start_time
                print(
                    f"generate step {dec_step}: speed={86 / duration:.3f} tokens/s, realtime factor={1 / duration:.3f}x"
                )
                start_time = time.time()

        if dec_output.prefill_step >= dec_step + 1:
            print("Warning: Nothing generated")
            return None

        generated_codes = dec_output.generated_tokens[
            dec_output.prefill_step : dec_step + 1, :
        ]

        if verbose:
            total_step = dec_step + 1 - dec_output.prefill_step
            total_duration = time.time() - total_start_time
            print(
                f"generate: total step={total_step}, total duration={total_duration:.3f}s"
            )

        return self._generate_output(generated_codes)
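

# Minimal end-to-end sketch. This module uses relative imports, so run it as
# part of its package, e.g. `python -m dia.model` if the package is named
# `dia` (an assumption); the prompt text and output path are illustrative.
if __name__ == "__main__":
    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float16")
    audio = model.generate(
        "[S1] Hello there. [S2] Hi, nice to meet you.",
        verbose=True,
    )
    if audio is not None:
        model.save_audio("example.wav", audio)  # 44.1 kHz WAV via soundfile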