|
import tempfile |
|
import time |
|
from pathlib import Path |
|
from typing import Optional, Tuple |
|
import spaces |
|
|
|
import gradio as gr |
|
import numpy as np |
|
import soundfile as sf |
|
import torch |
|
|
|
from dia.model import Dia |
|
|
|
|
|
|
|
# Load the model once at import time; run_inference() reads this module-level
# instance on every request instead of reloading the weights.
print("Loading Nari model...")
try:
    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
except Exception as load_error:
    # Log the failure for the console, then re-raise unchanged so start-up
    # aborts rather than serving a UI that has no model behind it.
    print(f"Error loading Nari model: {load_error}")
    raise
|
|
|
|
|
def _normalize_audio_prompt(audio_data: np.ndarray) -> np.ndarray:
    """Return the audio prompt as a single-channel floating-point array.

    Integer samples are cast to float32 and divided by the maximum value of
    their dtype; other non-floating dtypes are cast to float32 (with a UI
    warning); floating-point input keeps its dtype. Arrays with more than one
    dimension are reduced to one channel and made contiguous.

    Raises:
        gr.Error: If a non-numeric dtype cannot be cast to float32.
    """
    if np.issubdtype(audio_data.dtype, np.integer):
        # NOTE(review): dividing by iinfo.max assumes signed PCM centred on
        # zero; unsigned input would not be re-centred -- confirm if needed.
        max_val = np.iinfo(audio_data.dtype).max
        audio_data = audio_data.astype(np.float32) / max_val
    elif not np.issubdtype(audio_data.dtype, np.floating):
        gr.Warning(
            f"Unsupported audio prompt dtype {audio_data.dtype}, attempting conversion."
        )
        try:
            audio_data = audio_data.astype(np.float32)
        except Exception as conv_e:
            raise gr.Error(
                f"Failed to convert audio prompt to float32: {conv_e}"
            ) from conv_e

    if audio_data.ndim > 1:
        if audio_data.shape[0] == 2:
            # Treated as (2, N): average the two channels.
            audio_data = np.mean(audio_data, axis=0)
        elif audio_data.shape[1] == 2:
            # Treated as (N, 2): average the two channels.
            audio_data = np.mean(audio_data, axis=1)
        else:
            gr.Warning(
                f"Audio prompt has unexpected shape {audio_data.shape}, taking first channel/axis."
            )
            # Slice along the shorter leading axis, taken to be the channel axis.
            audio_data = (
                audio_data[0]
                if audio_data.shape[0] < audio_data.shape[1]
                else audio_data[:, 0]
            )
        # Slicing can yield a strided view; hand the writer a contiguous buffer.
        audio_data = np.ascontiguousarray(audio_data)

    return audio_data


def _apply_speed_factor(audio: np.ndarray, speed_factor: float) -> np.ndarray:
    """Stretch or compress ``audio`` by linear interpolation.

    ``speed_factor`` is clamped to [0.1, 5.0] and the output holds
    ``int(len(audio) / speed_factor)`` samples. The input is returned
    unchanged when that length equals the original length or is not positive.
    """
    speed_factor = max(0.1, min(speed_factor, 5.0))
    original_len = len(audio)
    target_len = int(original_len / speed_factor)

    if target_len != original_len and target_len > 0:
        x_original = np.arange(original_len)
        x_resampled = np.linspace(0, original_len - 1, target_len)
        resampled = np.interp(x_resampled, x_original, audio).astype(np.float32)
        print(
            f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed."
        )
        return resampled

    print(f"Skipping audio speed adjustment (factor: {speed_factor:.2f}).")
    return audio


@spaces.GPU
def run_inference(
    text_input: str,
    audio_prompt_input: Optional[Tuple[int, np.ndarray]],
    max_new_tokens: int,
    cfg_scale: float,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int,
    speed_factor: float,
) -> Tuple[int, np.ndarray]:
    """Generate speech for ``text_input`` with the globally loaded model.

    An optional audio prompt is normalised, written to a temporary WAV file
    and passed to ``model.generate`` by path; the file is removed before the
    function returns, whether generation succeeded or not.

    Args:
        text_input: Script to synthesise; must contain non-whitespace text.
        audio_prompt_input: ``(sample_rate, samples)`` pair or ``None``.
            Empty or all-zero prompts are ignored with a UI warning.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_tokens``.
        cfg_scale: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        cfg_filter_top_k: Forwarded to ``model.generate``.
        speed_factor: Playback speed multiplier, clamped to [0.1, 5.0];
            applied by resampling the generated waveform.

    Returns:
        ``(44100, samples)``. Float32/float64 output is clipped to [-1, 1]
        and converted to int16. If the model returns ``None``, a one-sample
        float32 silence buffer is returned instead.

    Raises:
        gr.Error: On empty text, on prompt conversion/write failure, or when
            any other exception occurs during inference.
    """
    if not text_input or text_input.isspace():
        raise gr.Error("Text input cannot be empty.")

    output_sr = 44100
    temp_audio_prompt_path = None
    # Fallback result used when generation yields nothing.
    output_audio = (output_sr, np.zeros(1, dtype=np.float32))

    try:
        prompt_path_for_generate = None
        if audio_prompt_input is not None:
            sr, audio_data = audio_prompt_input

            # np.any() detects true silence; comparing max() with 0 would also
            # reject a prompt whose samples are all <= 0.
            if audio_data is None or audio_data.size == 0 or not np.any(audio_data):
                gr.Warning("Audio prompt seems empty or silent, ignoring prompt.")
            else:
                # Reserve a path only, and close the handle before writing:
                # reopening a still-open NamedTemporaryFile by name fails on
                # Windows.
                with tempfile.NamedTemporaryFile(
                    mode="wb", suffix=".wav", delete=False
                ) as f_audio:
                    temp_audio_prompt_path = f_audio.name

                audio_data = _normalize_audio_prompt(audio_data)

                try:
                    sf.write(temp_audio_prompt_path, audio_data, sr, subtype="FLOAT")
                except Exception as write_e:
                    print(f"Error writing temporary audio file: {write_e}")
                    raise gr.Error(
                        f"Failed to save audio prompt: {write_e}"
                    ) from write_e

                prompt_path_for_generate = temp_audio_prompt_path
                print(
                    f"Created temporary audio prompt file: {temp_audio_prompt_path} (orig sr: {sr})"
                )

        start_time = time.time()

        with torch.inference_mode():
            output_audio_np = model.generate(
                text_input,
                max_tokens=max_new_tokens,
                cfg_scale=cfg_scale,
                temperature=temperature,
                top_p=top_p,
                cfg_filter_top_k=cfg_filter_top_k,
                use_torch_compile=False,
                audio_prompt=prompt_path_for_generate,
            )

        end_time = time.time()
        print(f"Generation finished in {end_time - start_time:.2f} seconds.")

        if output_audio_np is not None:
            audio = _apply_speed_factor(output_audio_np, speed_factor)
            print(
                f"Audio conversion successful. Final shape: {audio.shape}, Sample Rate: {output_sr}"
            )

            # Return int16 PCM for float output, clipping first so the cast
            # cannot wrap around.
            if audio.dtype == np.float32 or audio.dtype == np.float64:
                audio = (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)
                print("Converted audio to int16 for Gradio output.")

            output_audio = (output_sr, audio)
        else:
            print("\nGeneration finished, but no valid tokens were produced.")
            gr.Warning("Generation produced no output.")

    except gr.Error:
        # Already a user-facing message; do not re-wrap it as
        # "Inference failed: ...".
        raise
    except Exception as e:
        print(f"Error during inference: {e}")
        import traceback

        traceback.print_exc()
        raise gr.Error(f"Inference failed: {e}") from e

    finally:
        # Always remove the prompt file; delete=False means nothing else will.
        if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
            try:
                Path(temp_audio_prompt_path).unlink()
                print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
            except OSError as e:
                print(
                    f"Warning: Error deleting temporary audio prompt file {temp_audio_prompt_path}: {e}"
                )

    return output_audio
|
|
|
|
|
|
|
# Page-level stylesheet handed to gr.Blocks: caps the width of the element
# with id "col-container" and centres it horizontally.
css = """
#col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
"""

# Built-in sample script shown in the text box; replaced below when a readable
# ./example.txt is present.
default_text = "[S1] Dia is an open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] Wow. Amazing. (laughs) \n[S2] Try it now on Git hub or Hugging Face."

example_txt_path = Path("./example.txt")
if example_txt_path.exists():
    try:
        _example_text = example_txt_path.read_text(encoding="utf-8").strip()
    except Exception as read_error:
        # Keep the built-in sample when the file cannot be read.
        print(f"Warning: Could not read example.txt: {read_error}")
    else:
        # An empty (or whitespace-only) file yields a visible placeholder.
        default_text = _example_text or "Example text file was empty."
|
|
|
|
|
|
|
# Build the Gradio UI: inputs and generation settings in the left column, the
# generated audio in the right column, and clickable examples underneath.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Nari Text-to-Speech Synthesis")

    with gr.Row(equal_height=False):
        # Left column: script, optional voice prompt, sampling controls.
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter text here...",
                value=default_text,
                lines=5,
            )
            audio_prompt_input = gr.Audio(
                label="Audio Prompt (Optional)",
                show_label=True,
                sources=["upload", "microphone"],
                type="numpy",
            )
            with gr.Accordion("Generation Parameters", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens (Audio Length)",
                    minimum=860,
                    maximum=3072,
                    # Initial value comes from the loaded model's configuration.
                    value=model.config.data.audio_length,
                    step=50,
                    info="Controls the maximum length of the generated audio (more tokens = longer audio).",
                )
                cfg_scale = gr.Slider(
                    label="CFG Scale (Guidance Strength)",
                    minimum=1.0,
                    maximum=5.0,
                    value=3.0,
                    step=0.1,
                    info="Higher values increase adherence to the text prompt.",
                )
                temperature = gr.Slider(
                    label="Temperature (Randomness)",
                    minimum=1.0,
                    maximum=1.5,
                    value=1.3,
                    step=0.05,
                    info="Lower values make the output more deterministic, higher values increase randomness.",
                )
                top_p = gr.Slider(
                    label="Top P (Nucleus Sampling)",
                    minimum=0.80,
                    maximum=1.0,
                    value=0.95,
                    step=0.01,
                    info="Filters vocabulary to the most likely tokens cumulatively reaching probability P.",
                )
                cfg_filter_top_k = gr.Slider(
                    label="CFG Filter Top K",
                    minimum=15,
                    maximum=50,
                    value=30,
                    step=1,
                    info="Top k filter for CFG guidance.",
                )
                speed_factor_slider = gr.Slider(
                    label="Speed Factor",
                    minimum=0.8,
                    maximum=1.0,
                    value=0.94,
                    step=0.02,
                    info="Adjusts the speed of the generated audio (1.0 = original speed).",
                )

            run_button = gr.Button("Generate Audio", variant="primary")

        # Right column: playback of the generated result.
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Audio",
                type="numpy",
                autoplay=False,
            )

    # Components in the positional order run_inference() expects; shared by
    # the button handler and the examples widget so the two cannot drift.
    inference_inputs = [
        text_input,
        audio_prompt_input,
        max_new_tokens,
        cfg_scale,
        temperature,
        top_p,
        cfg_filter_top_k,
        speed_factor_slider,
    ]

    run_button.click(
        fn=run_inference,
        inputs=inference_inputs,
        outputs=[audio_output],
        api_name="generate_audio",
    )

    # The bundled voice prompt is optional: use it only when the file exists.
    example_prompt_path = "./example_prompt.mp3"
    example_prompt = (
        example_prompt_path if Path(example_prompt_path).exists() else None
    )

    # max_new_tokens, cfg_scale, temperature, top_p, cfg_filter_top_k, speed
    # factor -- identical for both examples.
    example_settings = [3072, 3.0, 1.3, 0.95, 35, 0.94]

    examples_list = [
        [
            "[S1] Oh fire! Oh my goodness! What's the procedure? What to we do people? The smoke could be coming through an air duct! \n[S2] Oh my god! Okay.. it's happening. Everybody stay calm! \n[S1] What's the procedure... \n[S2] Everybody stay fucking calm!!!... Everybody fucking calm down!!!!! \n[S1] No! No! If you touch the handle, if its hot there might be a fire down the hallway! ",
            None,
            *example_settings,
        ],
        [
            "[S1] Open weights text to dialogue model. \n[S2] You get full control over scripts and voices. \n[S1] I'm biased, but I think we clearly won. \n[S2] Hard to disagree. (laughs) \n[S1] Thanks for listening to this demo. \n[S2] Try it now on Git hub and Hugging Face. \n[S1] If you liked our model, please give us a star and share to your friends. \n[S2] This was Nari Labs.",
            example_prompt,
            *example_settings,
        ],
    ]

    if examples_list:
        gr.Examples(
            examples=examples_list,
            inputs=inference_inputs,
            outputs=[audio_output],
            fn=run_inference,
            cache_examples=False,
            label="Examples (Click to Run)",
        )
    else:
        gr.Markdown("_(No examples configured or example prompt file missing)_")
|
|
|
|
|
if __name__ == "__main__":
    # Start the web server only when run as a script, so importing this module
    # builds `demo` without launching it.
    print("Launching Gradio interface...")
    demo.launch()