Michael Hu commited on
Commit
4a9bb1a
·
1 Parent(s): 2d176f4

fix runtime dia audio is not generated issue

Browse files
Files changed (1) hide show
  1. utils/tts_dia.py +73 -98
utils/tts_dia.py CHANGED
@@ -114,108 +114,83 @@ def generate_speech(text: str, language: str = "zh") -> str:
114
  logger.info("Falling back to dummy TTS due to model loading error")
115
  from utils.tts_dummy import generate_speech as dummy_generate_speech
116
  return dummy_generate_speech(text, language)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
 
118
- # Generate audio
119
- logger.info("Starting audio generation with Dia model")
120
- start_time = time.time()
 
121
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  try:
123
- with torch.inference_mode():
124
- logger.info("Calling model.generate() with inference_mode")
125
- output_audio_np = model.generate(
126
- text,
127
- max_tokens=None, # Use default from model config
128
- cfg_scale=3.0,
129
- temperature=1.3,
130
- top_p=0.95,
131
- cfg_filter_top_k=35,
132
- use_torch_compile=False, # Keep False for stability
133
- verbose=False
134
- )
135
- logger.info("Model.generate() completed")
136
- except RuntimeError as rt_err:
137
- logger.error(f"Runtime error during generation: {rt_err}")
138
- if "CUDA out of memory" in str(rt_err):
139
- logger.error("CUDA out of memory error - consider reducing batch size or model size")
140
  # Fall back to dummy TTS
141
- logger.info("Falling back to dummy TTS due to runtime error during generation")
142
- from utils.tts_dummy import generate_speech as dummy_generate_speech
143
- return dummy_generate_speech(text, language)
144
- except Exception as gen_err:
145
- logger.error(f"Error during audio generation: {gen_err}")
146
- logger.error(f"Error type: {type(gen_err).__name__}")
147
- # Fall back to dummy TTS
148
- logger.info("Falling back to dummy TTS due to error during generation")
149
  from utils.tts_dummy import generate_speech as dummy_generate_speech
150
  return dummy_generate_speech(text, language)
151
 
152
- end_time = time.time()
153
- generation_time = end_time - start_time
154
- logger.info(f"Generation finished in {generation_time:.2f} seconds")
155
-
156
- # Process the output
157
- if output_audio_np is not None:
158
- logger.info(f"Generated audio array shape: {output_audio_np.shape}, dtype: {output_audio_np.dtype}")
159
- logger.info(f"Audio stats - min: {output_audio_np.min():.4f}, max: {output_audio_np.max():.4f}, mean: {output_audio_np.mean():.4f}")
160
-
161
- # Apply a slight slowdown for better quality (0.94x speed)
162
- speed_factor = 0.94
163
- original_len = len(output_audio_np)
164
- target_len = int(original_len / speed_factor)
165
-
166
- logger.info(f"Applying speed adjustment factor: {speed_factor}")
167
- if target_len != original_len and target_len > 0:
168
- try:
169
- x_original = np.arange(original_len)
170
- x_resampled = np.linspace(0, original_len - 1, target_len)
171
- output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
172
- logger.info(f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed")
173
- except Exception as resample_err:
174
- logger.error(f"Error during audio resampling: {resample_err}")
175
- logger.warning("Using original audio without resampling")
176
-
177
- # Save the audio file
178
- logger.info(f"Saving audio to file: {output_path}")
179
- try:
180
- sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
181
- logger.info(f"Audio successfully saved to {output_path}")
182
- except Exception as save_err:
183
- logger.error(f"Error saving audio file: {save_err}")
184
- logger.error(f"Error type: {type(save_err).__name__}")
185
- # Fall back to dummy TTS
186
- logger.info("Falling back to dummy TTS due to error saving audio file")
187
- from utils.tts_dummy import generate_speech as dummy_generate_speech
188
- return dummy_generate_speech(text, language)
189
-
190
- return output_path
191
- else:
192
- logger.warning("Generation produced no output (None returned from model)")
193
- logger.warning("This may indicate a model configuration issue or empty input text")
194
- dummy_path = f"{output_dir}/dummy_{timestamp}.wav"
195
- logger.warning(f"Returning dummy audio path: {dummy_path}")
196
- return dummy_path
197
-
198
- except Exception as e:
199
- logger.error(f"TTS generation failed: {str(e)}", exc_info=True)
200
- logger.error(f"Error type: {type(e).__name__}")
201
-
202
- # Log additional diagnostic information based on error type
203
- if isinstance(e, ImportError):
204
- logger.error(f"Import error - missing dependency: {e.__class__.__module__}.{e.__class__.__name__}")
205
- logger.error("Check if all required packages are installed correctly")
206
- elif isinstance(e, RuntimeError) and "CUDA" in str(e):
207
- logger.error("CUDA-related runtime error - check GPU compatibility and memory")
208
- elif isinstance(e, AttributeError):
209
- logger.error(f"Attribute error - likely API incompatibility or incorrect module version")
210
- if hasattr(e, '__traceback__'):
211
- tb = e.__traceback__
212
- while tb.tb_next:
213
- tb = tb.tb_next
214
- logger.error(f"Error occurred in file: {tb.tb_frame.f_code.co_filename}, line {tb.tb_lineno}")
215
- elif isinstance(e, FileNotFoundError):
216
- logger.error(f"File not found - check if model files exist and are accessible")
217
-
218
- # Fall back to dummy TTS
219
- logger.info("Falling back to dummy TTS due to unhandled exception")
220
- from utils.tts_dummy import generate_speech as dummy_generate_speech
221
- return dummy_generate_speech(text, language)
 
114
  logger.info("Falling back to dummy TTS due to model loading error")
115
  from utils.tts_dummy import generate_speech as dummy_generate_speech
116
  return dummy_generate_speech(text, language)
117
+
118
+ # Generate audio
119
+ logger.info("Starting audio generation with Dia model")
120
+ start_time = time.time()
121
+
122
+ try:
123
+ with torch.inference_mode():
124
+ logger.info("Calling model.generate() with inference_mode")
125
+ output_audio_np = model.generate(
126
+ text,
127
+ max_tokens=None, # Use default from model config
128
+ cfg_scale=3.0,
129
+ temperature=1.3,
130
+ top_p=0.95,
131
+ cfg_filter_top_k=35,
132
+ use_torch_compile=False, # Keep False for stability
133
+ verbose=False
134
+ )
135
+ logger.info("Model.generate() completed")
136
+ except RuntimeError as rt_err:
137
+ logger.error(f"Runtime error during generation: {rt_err}")
138
+ if "CUDA out of memory" in str(rt_err):
139
+ logger.error("CUDA out of memory error - consider reducing batch size or model size")
140
+ # Fall back to dummy TTS
141
+ logger.info("Falling back to dummy TTS due to runtime error during generation")
142
+ from utils.tts_dummy import generate_speech as dummy_generate_speech
143
+ return dummy_generate_speech(text, language)
144
+ except Exception as gen_err:
145
+ logger.error(f"Error during audio generation: {gen_err}")
146
+ logger.error(f"Error type: {type(gen_err).__name__}")
147
+ # Fall back to dummy TTS
148
+ logger.info("Falling back to dummy TTS due to error during generation")
149
+ from utils.tts_dummy import generate_speech as dummy_generate_speech
150
+ return dummy_generate_speech(text, language)
151
+
152
+ end_time = time.time()
153
+ generation_time = end_time - start_time
154
+ logger.info(f"Generation finished in {generation_time:.2f} seconds")
155
+
156
+ # Process the output
157
+ if output_audio_np is not None:
158
+ logger.info(f"Generated audio array shape: {output_audio_np.shape}, dtype: {output_audio_np.dtype}")
159
+ logger.info(f"Audio stats - min: {output_audio_np.min():.4f}, max: {output_audio_np.max():.4f}, mean: {output_audio_np.mean():.4f}")
160
 
161
+ # Apply a slight slowdown for better quality (0.94x speed)
162
+ speed_factor = 0.94
163
+ original_len = len(output_audio_np)
164
+ target_len = int(original_len / speed_factor)
165
 
166
+ logger.info(f"Applying speed adjustment factor: {speed_factor}")
167
+ if target_len != original_len and target_len > 0:
168
+ try:
169
+ x_original = np.arange(original_len)
170
+ x_resampled = np.linspace(0, original_len - 1, target_len)
171
+ output_audio_np = np.interp(x_resampled, x_original, output_audio_np)
172
+ logger.info(f"Resampled audio from {original_len} to {target_len} samples for {speed_factor:.2f}x speed")
173
+ except Exception as resample_err:
174
+ logger.error(f"Error during audio resampling: {resample_err}")
175
+ logger.warning("Using original audio without resampling")
176
+
177
+ # Save the audio file
178
+ logger.info(f"Saving audio to file: {output_path}")
179
  try:
180
+ sf.write(output_path, output_audio_np, DEFAULT_SAMPLE_RATE)
181
+ logger.info(f"Audio successfully saved to {output_path}")
182
+ except Exception as save_err:
183
+ logger.error(f"Error saving audio file: {save_err}")
184
+ logger.error(f"Error type: {type(save_err).__name__}")
 
 
 
 
 
 
 
 
 
 
 
 
185
  # Fall back to dummy TTS
186
+ logger.info("Falling back to dummy TTS due to error saving audio file")
 
 
 
 
 
 
 
187
  from utils.tts_dummy import generate_speech as dummy_generate_speech
188
  return dummy_generate_speech(text, language)
189
 
190
+ return output_path
191
+ else:
192
+ logger.warning("Generation produced no output (None returned from model)")
193
+ logger.warning("This may indicate a model configuration issue or empty input text")
194
+ dummy_path = f"{output_dir}/dummy_{timestamp}.wav"
195
+ logger.warning(f"Returning dummy audio path: {dummy_path}")
196
+ return dummy_path