Michael Hu committed on
Commit
2d176f4
·
1 Parent(s): 501f0b5

fix runtime dia model issue

Browse files
Files changed (2) hide show
  1. utils/tts_dia.py +56 -33
  2. utils/tts_dummy.py +23 -1
utils/tts_dia.py CHANGED
@@ -44,7 +44,11 @@ def _get_model() -> Dia:
44
  # Log model details
45
  logger.info(f"Dia model loaded successfully")
46
  logger.info(f"Model type: {type(_model).__name__}")
47
- logger.info(f"Model device: {next(_model.parameters()).device}")
 
 
 
 
48
  except ImportError as import_err:
49
  logger.error(f"Import error loading Dia model: {import_err}")
50
  logger.error(f"This may indicate missing dependencies")
@@ -74,34 +78,42 @@ def generate_speech(text: str, language: str = "zh") -> str:
74
  logger.info(f"Generating speech for text length: {len(text)}")
75
  logger.info(f"Text content (first 50 chars): {text[:50]}...")
76
 
 
 
 
77
  try:
78
- # Create output directory if it doesn't exist
79
- output_dir = "temp/outputs"
80
- logger.info(f"Ensuring output directory exists: {output_dir}")
81
- try:
82
- os.makedirs(output_dir, exist_ok=True)
83
- logger.info(f"Output directory ready: {output_dir}")
84
- except PermissionError as perm_err:
85
- logger.error(f"Permission error creating output directory: {perm_err}")
86
- raise
87
- except Exception as dir_err:
88
- logger.error(f"Error creating output directory: {dir_err}")
89
- raise
90
-
91
- # Generate unique output path
92
- timestamp = int(time.time())
93
- output_path = f"{output_dir}/output_{timestamp}.wav"
94
- logger.info(f"Output will be saved to: {output_path}")
95
-
96
- # Get the model
97
- logger.info("Retrieving Dia model instance")
98
- try:
99
- model = _get_model()
100
- logger.info("Successfully retrieved Dia model instance")
101
- except Exception as model_err:
102
- logger.error(f"Failed to get Dia model: {model_err}")
103
- logger.error(f"Error type: {type(model_err).__name__}")
104
- raise
 
 
 
 
 
105
 
106
  # Generate audio
107
  logger.info("Starting audio generation with Dia model")
@@ -125,11 +137,17 @@ def generate_speech(text: str, language: str = "zh") -> str:
125
  logger.error(f"Runtime error during generation: {rt_err}")
126
  if "CUDA out of memory" in str(rt_err):
127
  logger.error("CUDA out of memory error - consider reducing batch size or model size")
128
- raise
 
 
 
129
  except Exception as gen_err:
130
  logger.error(f"Error during audio generation: {gen_err}")
131
  logger.error(f"Error type: {type(gen_err).__name__}")
132
- raise
 
 
 
133
 
134
  end_time = time.time()
135
  generation_time = end_time - start_time
@@ -164,7 +182,10 @@ def generate_speech(text: str, language: str = "zh") -> str:
164
  except Exception as save_err:
165
  logger.error(f"Error saving audio file: {save_err}")
166
  logger.error(f"Error type: {type(save_err).__name__}")
167
- raise
 
 
 
168
 
169
  return output_path
170
  else:
@@ -194,5 +215,7 @@ def generate_speech(text: str, language: str = "zh") -> str:
194
  elif isinstance(e, FileNotFoundError):
195
  logger.error(f"File not found - check if model files exist and are accessible")
196
 
197
- # Return dummy path in case of error
198
- return "temp/outputs/dummy.wav"
 
 
 
44
  # Log model details
45
  logger.info(f"Dia model loaded successfully")
46
  logger.info(f"Model type: {type(_model).__name__}")
47
+ # Check if model has parameters method (PyTorch models do, but Dia might not)
48
+ if hasattr(_model, 'parameters'):
49
+ logger.info(f"Model device: {next(_model.parameters()).device}")
50
+ else:
51
+ logger.info("Model device: Device information not available for Dia model")
52
  except ImportError as import_err:
53
  logger.error(f"Import error loading Dia model: {import_err}")
54
  logger.error(f"This may indicate missing dependencies")
 
78
  logger.info(f"Generating speech for text length: {len(text)}")
79
  logger.info(f"Text content (first 50 chars): {text[:50]}...")
80
 
81
+ # Create output directory if it doesn't exist
82
+ output_dir = "temp/outputs"
83
+ logger.info(f"Ensuring output directory exists: {output_dir}")
84
  try:
85
+ os.makedirs(output_dir, exist_ok=True)
86
+ logger.info(f"Output directory ready: {output_dir}")
87
+ except PermissionError as perm_err:
88
+ logger.error(f"Permission error creating output directory: {perm_err}")
89
+ # Fall back to dummy TTS
90
+ logger.info("Falling back to dummy TTS due to directory creation error")
91
+ from utils.tts_dummy import generate_speech as dummy_generate_speech
92
+ return dummy_generate_speech(text, language)
93
+ except Exception as dir_err:
94
+ logger.error(f"Error creating output directory: {dir_err}")
95
+ # Fall back to dummy TTS
96
+ logger.info("Falling back to dummy TTS due to directory creation error")
97
+ from utils.tts_dummy import generate_speech as dummy_generate_speech
98
+ return dummy_generate_speech(text, language)
99
+
100
+ # Generate unique output path
101
+ timestamp = int(time.time())
102
+ output_path = f"{output_dir}/output_{timestamp}.wav"
103
+ logger.info(f"Output will be saved to: {output_path}")
104
+
105
+ # Get the model
106
+ logger.info("Retrieving Dia model instance")
107
+ try:
108
+ model = _get_model()
109
+ logger.info("Successfully retrieved Dia model instance")
110
+ except Exception as model_err:
111
+ logger.error(f"Failed to get Dia model: {model_err}")
112
+ logger.error(f"Error type: {type(model_err).__name__}")
113
+ # Fall back to dummy TTS
114
+ logger.info("Falling back to dummy TTS due to model loading error")
115
+ from utils.tts_dummy import generate_speech as dummy_generate_speech
116
+ return dummy_generate_speech(text, language)
117
 
118
  # Generate audio
119
  logger.info("Starting audio generation with Dia model")
 
137
  logger.error(f"Runtime error during generation: {rt_err}")
138
  if "CUDA out of memory" in str(rt_err):
139
  logger.error("CUDA out of memory error - consider reducing batch size or model size")
140
+ # Fall back to dummy TTS
141
+ logger.info("Falling back to dummy TTS due to runtime error during generation")
142
+ from utils.tts_dummy import generate_speech as dummy_generate_speech
143
+ return dummy_generate_speech(text, language)
144
  except Exception as gen_err:
145
  logger.error(f"Error during audio generation: {gen_err}")
146
  logger.error(f"Error type: {type(gen_err).__name__}")
147
+ # Fall back to dummy TTS
148
+ logger.info("Falling back to dummy TTS due to error during generation")
149
+ from utils.tts_dummy import generate_speech as dummy_generate_speech
150
+ return dummy_generate_speech(text, language)
151
 
152
  end_time = time.time()
153
  generation_time = end_time - start_time
 
182
  except Exception as save_err:
183
  logger.error(f"Error saving audio file: {save_err}")
184
  logger.error(f"Error type: {type(save_err).__name__}")
185
+ # Fall back to dummy TTS
186
+ logger.info("Falling back to dummy TTS due to error saving audio file")
187
+ from utils.tts_dummy import generate_speech as dummy_generate_speech
188
+ return dummy_generate_speech(text, language)
189
 
190
  return output_path
191
  else:
 
215
  elif isinstance(e, FileNotFoundError):
216
  logger.error(f"File not found - check if model files exist and are accessible")
217
 
218
+ # Fall back to dummy TTS
219
+ logger.info("Falling back to dummy TTS due to unhandled exception")
220
+ from utils.tts_dummy import generate_speech as dummy_generate_speech
221
+ return dummy_generate_speech(text, language)
utils/tts_dummy.py CHANGED
@@ -1,3 +1,25 @@
1
  def generate_speech(text: str, language: str = "zh") -> str:
2
  """Public interface for TTS generation"""
3
- return "temp/outputs/dummy.wav"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
def generate_speech(text: str, language: str = "zh") -> str:
    """Write a short placeholder tone to a WAV file and return its path.

    This is a stand-in for a real TTS backend: ``text`` and ``language`` are
    accepted so the signature matches the public TTS interface, but neither
    one influences the generated audio.

    Args:
        text: Text that would be spoken by a real backend (unused here).
        language: Language code of ``text`` (unused here).

    Returns:
        Path of the WAV file that was written, relative to the current
        working directory (``temp/outputs/dummy_<timestamp>_<suffix>.wav``).

    Raises:
        OSError: If the output directory cannot be created.
    """
    # Imports stay local so that importing this module does not require the
    # audio dependencies until a dummy file is actually requested.
    import os
    import time
    import uuid

    import numpy as np
    import soundfile as sf

    # Create the output directory if it doesn't exist.
    output_dir = "temp/outputs"
    os.makedirs(output_dir, exist_ok=True)

    # A second-resolution timestamp alone is not unique: two calls within the
    # same second would target the same path and the later call would
    # overwrite the earlier result. A random suffix keeps the paths distinct.
    timestamp = int(time.time())
    unique_suffix = uuid.uuid4().hex[:8]
    output_path = f"{output_dir}/dummy_{timestamp}_{unique_suffix}.wav"

    # Parameters of the placeholder sine tone.
    sample_rate = 24000  # samples per second
    duration = 2.0       # seconds
    frequency = 440.0    # Hz
    amplitude = 0.3      # peak amplitude of the generated samples

    # endpoint=False leaves out the sample at t == duration, so exactly
    # sample_rate * duration samples are produced.
    num_samples = int(sample_rate * duration)
    t = np.linspace(0, duration, num_samples, endpoint=False)
    tone = amplitude * np.sin(2 * np.pi * frequency * t)

    # Save the audio file.
    sf.write(output_path, tone, sample_rate)

    return output_path