Michael Hu commited on
Commit
030c851
·
1 Parent(s): 4b0381b

add dia tts model

Browse files
Files changed (1) hide show
  1. utils/tts_dia.py +103 -0
utils/tts_dia.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import logging
4
+ import torch
5
+ import numpy as np
6
+ import soundfile as sf
7
+ from pathlib import Path
8
+ from typing import Optional
9
+
10
+ from dia.model import Dia
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # Constants
17
+ DEFAULT_SAMPLE_RATE = 44100
18
+ DEFAULT_MODEL_NAME = "nari-labs/Dia-1.6B"
19
+
20
+ # Global model instance (lazy loaded)
21
+ _model = None
22
+
23
+
24
+ def _get_model() -> Dia:
25
+ """Lazy-load the Dia model to avoid loading it until needed"""
26
+ global _model
27
+ if _model is None:
28
+ logger.info("Loading Dia model...")
29
+ try:
30
+ _model = Dia.from_pretrained(DEFAULT_MODEL_NAME, compute_dtype="float16")
31
+ logger.info("Dia model loaded successfully")
32
+ except Exception as e:
33
+ logger.error(f"Error loading Dia model: {e}", exc_info=True)
34
+ raise
35
+ return _model
36
+
37
+
38
+ def generate_speech(text: str, language: str = "zh") -> str:
39
+ """Public interface for TTS generation using Dia model
40
+
41
+ Args:
42
+ text (str): Input text to synthesize
43
+ language (str): Language code (not used in Dia model, kept for API compatibility)
44
+
45
+ Returns:
46
+ str: Path to the generated audio file
47
+ """
48
+ logger.info(f"Generating speech for text length: {len(text)}")
49
+
50
+ try:
51
+ # Create output directory if it doesn't exist
52
+ os.makedirs("temp/outputs", exist_ok=True)
53
+
54
+ # Generate unique output path
55
+ output_path = f"temp/outputs/output_{int(time.time())}.wav"
56
+
57
+ # Get the model
58
+ model = _get_model()
59
+
60
+ # Generate audio
61
+ start_time = time.time()
62
+
63
+ with torch.inference_mode():
64
+ output_audio_np = model.generate(
65
+ text,
66
+ max_tokens=None, # Use default from model config
67
+ cfg_scale=3.0,
68
+ temperature=1.3,
69
+ top_p=0.95,
70
+ cfg_filter_top_k=35,
71
+ use_torch_compile=False, # Keep False for stability
72
+ verbose=False
73
+ )
74
+
75
+ end_time = time.time()
76
+ logger.info(f"Generation finished in {end_time - start_time:.2f} seconds")
77
+
78
+ # Process the output
79
+ if output_audio_np is not None:
80
+ # Apply a slight slowdown for better quality (0.94x speed)
81
+ speed_factor = 0.94
82
+ original_len = len(output_audio_np)
83
+ target_len = int(original_len / speed_factor)
84
+
85
+ if target_len != original_len and target_len > 0:
86
+ x_original = np.arange(original_len)
87
+ x_resampled = np.linspace(0, original_len - 1, target_len)
88
+ output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
89
+ logger.info(f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed")
90
+
91
+ # Save the audio file
92
+ sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
93
+ logger.info(f"Audio saved to {output_path}")
94
+
95
+ return output_path
96
+ else:
97
+ logger.warning("Generation produced no output, returning dummy audio")
98
+ return "temp/outputs/dummy.wav"
99
+
100
+ except Exception as e:
101
+ logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
102
+ # Return dummy path in case of error
103
+ return "temp/outputs/dummy.wav"