Michael Hu committed on
Commit
f0248ed
·
1 Parent(s): 4a9bb1a

move to Gradio so we can leverage ZeroGPU

Browse files
Files changed (4) hide show
  1. app.py +172 -128
  2. app_gradio.py +237 -0
  3. requirements.txt +5 -2
  4. utils/tts_dia.py +2 -1
app.py CHANGED
@@ -1,5 +1,4 @@
1
- """
2
- Main entry point for the Audio Translation Web Application
3
  Handles file upload, processing pipeline, and UI rendering
4
  """
5
 
@@ -14,180 +13,225 @@ logging.basicConfig(
14
  )
15
  logger = logging.getLogger(__name__)
16
 
17
- import streamlit as st
18
  import os
19
  import time
20
- import subprocess
 
21
  from utils.stt import transcribe_audio
22
  from utils.translation import translate_text
23
- from utils.tts import get_tts_engine, generate_speech
24
 
25
  # Initialize environment configurations
26
  os.makedirs("temp/uploads", exist_ok=True)
27
  os.makedirs("temp/outputs", exist_ok=True)
28
 
29
- def configure_page():
30
- """Set up Streamlit page configuration"""
31
- logger.info("Configuring Streamlit page")
32
- st.set_page_config(
33
- page_title="Audio Translator",
34
- page_icon="🎧",
35
- layout="wide",
36
- initial_sidebar_state="expanded"
37
- )
38
- st.markdown("""
39
- <style>
40
- .reportview-container {margin-top: -2em;}
41
- #MainMenu {visibility: hidden;}
42
- .stDeployButton {display:none;}
43
- .stAlert {padding: 20px !important;}
44
- </style>
45
- """, unsafe_allow_html=True)
46
 
47
- def handle_file_processing(upload_path):
48
  """
49
  Execute the complete processing pipeline:
50
  1. Speech-to-Text (STT)
51
  2. Machine Translation
52
  3. Text-to-Speech (TTS)
 
 
 
 
 
 
53
  """
54
- logger.info(f"Starting processing for: {upload_path}")
55
- progress_bar = st.progress(0)
56
- status_text = st.empty()
57
 
58
  try:
 
 
 
 
 
 
59
  # STT Phase
60
  logger.info("Beginning STT processing")
61
- status_text.markdown("πŸ” **Performing Speech Recognition...**")
62
- with st.spinner("Initializing Whisper model..."):
63
- english_text = transcribe_audio(upload_path)
64
- progress_bar.progress(30)
65
  logger.info(f"STT completed. Text length: {len(english_text)} characters")
66
 
67
  # Translation Phase
68
  logger.info("Beginning translation")
69
- status_text.markdown("🌐 **Translating Content...**")
70
- with st.spinner("Loading translation model..."):
71
- chinese_text = translate_text(english_text)
72
- progress_bar.progress(60)
73
  logger.info(f"Translation completed. Translated length: {len(chinese_text)} characters")
74
 
75
  # TTS Phase
76
  logger.info("Beginning TTS generation")
77
- status_text.markdown("🎡 **Generating Chinese Speech...**")
78
 
79
  # Initialize TTS engine with appropriate language code for Chinese
80
  engine = get_tts_engine(lang_code='z') # 'z' for Mandarin Chinese
81
 
82
  # Generate speech and get the file path
83
  output_path = engine.generate_speech(chinese_text, voice="zf_xiaobei")
84
- progress_bar.progress(100)
85
  logger.info(f"TTS completed. Output file: {output_path}")
86
 
87
- # Store the text for streaming playback
88
- st.session_state.current_text = chinese_text
89
 
90
- status_text.success("βœ… Processing Complete!")
91
- return english_text, chinese_text, output_path
92
 
93
  except Exception as e:
94
  logger.error(f"Processing failed: {str(e)}", exc_info=True)
95
- status_text.error(f"❌ Processing Failed: {str(e)}")
96
- st.exception(e)
97
- raise
98
 
99
- def render_results(english_text, chinese_text, output_path):
100
- """Display processing results in organized columns"""
101
- logger.info("Rendering results")
102
- st.divider()
103
 
104
- col1, col2 = st.columns([2, 1])
105
- with col1:
106
- st.subheader("Recognition Results")
107
- st.code(english_text, language="text")
108
 
109
- st.subheader("Translation Results")
110
- st.code(chinese_text, language="text")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
 
112
- with col2:
113
- st.subheader("Audio Output")
114
- # Standard audio player for the full file
115
- st.audio(output_path)
116
-
117
- # Download button
118
- with open(output_path, "rb") as f:
119
- st.download_button(
120
- label="Download Audio",
121
- data=f,
122
- file_name="translated_audio.wav",
123
- mime="audio/wav"
124
- )
125
-
126
- # Streaming playback controls
127
- st.subheader("Streaming Playback")
128
- if st.button("Stream Audio"):
129
- engine = get_tts_engine(lang_code='z')
130
- streaming_placeholder = st.empty()
131
-
132
- # Stream the audio in chunks
133
- for sample_rate, audio_chunk in engine.generate_speech_stream(
134
- chinese_text,
135
- voice="zf_xiaobei"
136
- ):
137
- # Create a temporary file for each chunk
138
- temp_chunk_path = f"temp/outputs/chunk_{time.time()}.wav"
139
- import soundfile as sf
140
- sf.write(temp_chunk_path, audio_chunk, sample_rate)
141
 
142
- # Play the chunk
143
- with streaming_placeholder:
144
- st.audio(temp_chunk_path, sample_rate=sample_rate)
145
 
146
- # Clean up the temporary chunk file
147
- os.remove(temp_chunk_path)
148
-
149
- def initialize_session_state():
150
- """Initialize session state variables"""
151
- if 'current_text' not in st.session_state:
152
- st.session_state.current_text = None
153
-
154
- def main():
155
- """Main application workflow"""
156
- logger.info("Starting application")
157
- configure_page()
158
- initialize_session_state()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
- st.title("🎧 High-Quality Audio Translation System")
161
- st.markdown("Upload English Audio β†’ Get Chinese Speech Output")
162
-
163
- # Voice selection in sidebar
164
- st.sidebar.header("TTS Settings")
165
- voice_options = {
166
- "Xiaobei (Female)": "zf_xiaobei",
167
- "Yunjian (Male)": "zm_yunjian",
168
- }
169
- selected_voice = st.sidebar.selectbox(
170
- "Select Voice",
171
- list(voice_options.keys()),
172
- format_func=lambda x: x
173
- )
174
- speed = st.sidebar.slider("Speech Speed", 0.5, 2.0, 1.0, 0.1)
175
 
176
- uploaded_file = st.file_uploader(
177
- "Select Audio File (MP3/WAV)",
178
- type=["mp3", "wav"],
179
- accept_multiple_files=False
180
- )
181
-
182
- if uploaded_file:
183
- logger.info(f"File uploaded: {uploaded_file.name}")
184
- upload_path = os.path.join("temp/uploads", uploaded_file.name)
185
- with open(upload_path, "wb") as f:
186
- f.write(uploaded_file.getbuffer())
187
-
188
- results = handle_file_processing(upload_path)
189
- if results:
190
- render_results(*results)
191
 
192
  if __name__ == "__main__":
193
  main()
 
1
+ """Main entry point for the Audio Translation Web Application using Gradio
 
2
  Handles file upload, processing pipeline, and UI rendering
3
  """
4
 
 
13
  )
14
  logger = logging.getLogger(__name__)
15
 
16
+ import gradio as gr
17
  import os
18
  import time
19
+ import numpy as np
20
+ import soundfile as sf
21
  from utils.stt import transcribe_audio
22
  from utils.translation import translate_text
23
+ from utils.tts import get_tts_engine
24
 
25
  # Initialize environment configurations
26
  os.makedirs("temp/uploads", exist_ok=True)
27
  os.makedirs("temp/outputs", exist_ok=True)
28
 
29
+ # CSS for styling the Gradio interface
30
+ css = """
31
+ .gradio-container {
32
+ max-width: 1200px;
33
+ margin: 0 auto;
34
+ }
35
+
36
+ .output-text {
37
+ font-family: monospace;
38
+ padding: 10px;
39
+ background-color: #f5f5f5;
40
+ border-radius: 4px;
41
+ }
42
+ """
 
 
 
43
 
44
+ def handle_file_processing(audio_file):
45
  """
46
  Execute the complete processing pipeline:
47
  1. Speech-to-Text (STT)
48
  2. Machine Translation
49
  3. Text-to-Speech (TTS)
50
+
51
+ Args:
52
+ audio_file: Tuple containing (sample_rate, audio_data)
53
+
54
+ Returns:
55
+ Tuple containing (english_text, chinese_text, output_audio)
56
  """
57
+ logger.info("Starting processing for uploaded audio")
 
 
58
 
59
  try:
60
+ # Save the uploaded audio to a temporary file
61
+ sr, audio_data = audio_file
62
+ temp_path = os.path.join("temp/uploads", f"upload_{time.time()}.wav")
63
+ sf.write(temp_path, audio_data, sr)
64
+ logger.info(f"Saved uploaded audio to {temp_path}")
65
+
66
  # STT Phase
67
  logger.info("Beginning STT processing")
68
+ english_text = transcribe_audio(temp_path)
 
 
 
69
  logger.info(f"STT completed. Text length: {len(english_text)} characters")
70
 
71
  # Translation Phase
72
  logger.info("Beginning translation")
73
+ chinese_text = translate_text(english_text)
 
 
 
74
  logger.info(f"Translation completed. Translated length: {len(chinese_text)} characters")
75
 
76
  # TTS Phase
77
  logger.info("Beginning TTS generation")
 
78
 
79
  # Initialize TTS engine with appropriate language code for Chinese
80
  engine = get_tts_engine(lang_code='z') # 'z' for Mandarin Chinese
81
 
82
  # Generate speech and get the file path
83
  output_path = engine.generate_speech(chinese_text, voice="zf_xiaobei")
 
84
  logger.info(f"TTS completed. Output file: {output_path}")
85
 
86
+ # Load the generated audio for Gradio output
87
+ audio_data, sr = sf.read(output_path)
88
 
89
+ return english_text, chinese_text, (sr, audio_data)
 
90
 
91
  except Exception as e:
92
  logger.error(f"Processing failed: {str(e)}", exc_info=True)
93
+ raise gr.Error(f"Processing Failed: {str(e)}")
 
 
94
 
95
+ def stream_audio(chinese_text, voice, speed):
96
+ """
97
+ Stream audio in chunks for the Gradio interface
 
98
 
99
+ Args:
100
+ chinese_text: The Chinese text to convert to speech
101
+ voice: The voice to use
102
+ speed: The speech speed factor
103
 
104
+ Returns:
105
+ Generator yielding audio chunks
106
+ """
107
+ engine = get_tts_engine(lang_code='z')
108
+
109
+ # Stream the audio in chunks
110
+ for sample_rate, audio_chunk in engine.generate_speech_stream(
111
+ chinese_text,
112
+ voice=voice,
113
+ speed=speed
114
+ ):
115
+ # Create a temporary file for each chunk
116
+ temp_chunk_path = f"temp/outputs/chunk_{time.time()}.wav"
117
+ sf.write(temp_chunk_path, audio_chunk, sample_rate)
118
+
119
+ # Load the chunk for Gradio output
120
+ chunk_data, sr = sf.read(temp_chunk_path)
121
+
122
+ # Clean up the temporary chunk file
123
+ os.remove(temp_chunk_path)
124
+
125
+ yield (sr, chunk_data)
126
 
127
+ def create_interface():
128
+ """
129
+ Create and configure the Gradio interface
130
+
131
+ Returns:
132
+ Gradio Blocks interface
133
+ """
134
+ with gr.Blocks(css=css) as interface:
135
+ gr.Markdown("# 🎧 High-Quality Audio Translation System")
136
+ gr.Markdown("Upload English Audio β†’ Get Chinese Speech Output")
137
+
138
+ with gr.Row():
139
+ with gr.Column(scale=2):
140
+ # File upload component
141
+ audio_input = gr.Audio(
142
+ label="Upload English Audio",
143
+ type="numpy",
144
+ sources=["upload", "microphone"]
145
+ )
 
 
 
 
 
 
 
 
 
 
146
 
147
+ # Process button
148
+ process_btn = gr.Button("Process Audio", variant="primary")
 
149
 
150
+ with gr.Column(scale=1):
151
+ # TTS Settings
152
+ with gr.Box():
153
+ gr.Markdown("### TTS Settings")
154
+ voice_dropdown = gr.Dropdown(
155
+ choices=["Xiaobei (Female)", "Yunjian (Male)"],
156
+ value="Xiaobei (Female)",
157
+ label="Select Voice"
158
+ )
159
+ speed_slider = gr.Slider(
160
+ minimum=0.5,
161
+ maximum=2.0,
162
+ value=1.0,
163
+ step=0.1,
164
+ label="Speech Speed"
165
+ )
166
+
167
+ # Output section
168
+ with gr.Row():
169
+ with gr.Column(scale=2):
170
+ # Text outputs
171
+ english_output = gr.Textbox(
172
+ label="Recognition Results",
173
+ lines=5,
174
+ elem_classes=["output-text"]
175
+ )
176
+
177
+ chinese_output = gr.Textbox(
178
+ label="Translation Results",
179
+ lines=5,
180
+ elem_classes=["output-text"]
181
+ )
182
+
183
+ with gr.Column(scale=1):
184
+ # Audio output
185
+ audio_output = gr.Audio(
186
+ label="Audio Output",
187
+ type="numpy"
188
+ )
189
+
190
+ # Stream button
191
+ stream_btn = gr.Button("Stream Audio")
192
+
193
+ # Download button is automatically provided by gr.Audio
194
+
195
+ # Set up event handlers
196
+ process_btn.click(
197
+ fn=handle_file_processing,
198
+ inputs=[audio_input],
199
+ outputs=[english_output, chinese_output, audio_output]
200
+ )
201
+
202
+ # Map voice selection to actual voice IDs
203
+ def get_voice_id(voice_name):
204
+ voice_map = {
205
+ "Xiaobei (Female)": "zf_xiaobei",
206
+ "Yunjian (Male)": "zm_yunjian"
207
+ }
208
+ return voice_map.get(voice_name, "zf_xiaobei")
209
+
210
+ # Stream button handler
211
+ stream_btn.click(
212
+ fn=lambda text, voice, speed: stream_audio(text, get_voice_id(voice), speed),
213
+ inputs=[chinese_output, voice_dropdown, speed_slider],
214
+ outputs=audio_output
215
+ )
216
+
217
+ # Examples
218
+ gr.Examples(
219
+ examples=[
220
+ ["examples/sample1.mp3"],
221
+ ["examples/sample2.wav"]
222
+ ],
223
+ inputs=audio_input
224
+ )
225
 
226
+ return interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
228
+ def main():
229
+ """
230
+ Main application entry point
231
+ """
232
+ logger.info("Starting Gradio application")
233
+ interface = create_interface()
234
+ interface.launch()
 
 
 
 
 
 
 
 
235
 
236
  if __name__ == "__main__":
237
  main()
app_gradio.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Main entry point for the Audio Translation Web Application using Gradio
2
+ Handles file upload, processing pipeline, and UI rendering
3
+ """
4
+
5
+ import logging
6
+ logging.basicConfig(
7
+ level=logging.INFO,
8
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
9
+ handlers=[
10
+ logging.FileHandler("app.log"),
11
+ logging.StreamHandler()
12
+ ]
13
+ )
14
+ logger = logging.getLogger(__name__)
15
+
16
+ import gradio as gr
17
+ import os
18
+ import time
19
+ import numpy as np
20
+ import soundfile as sf
21
+ from utils.stt import transcribe_audio
22
+ from utils.translation import translate_text
23
+ from utils.tts import get_tts_engine, generate_speech
24
+
25
+ # Initialize environment configurations
26
+ os.makedirs("temp/uploads", exist_ok=True)
27
+ os.makedirs("temp/outputs", exist_ok=True)
28
+
29
+ # CSS for styling the Gradio interface
30
+ css = """
31
+ .gradio-container {
32
+ max-width: 1200px;
33
+ margin: 0 auto;
34
+ }
35
+
36
+ .output-text {
37
+ font-family: monospace;
38
+ padding: 10px;
39
+ background-color: #f5f5f5;
40
+ border-radius: 4px;
41
+ }
42
+ """
43
+
44
def handle_file_processing(audio_file):
    """
    Execute the complete processing pipeline:
      1. Speech-to-Text (STT)
      2. Machine Translation
      3. Text-to-Speech (TTS)

    Args:
        audio_file: Tuple of (sample_rate, audio_data) as delivered by a
            gr.Audio component with type="numpy", or None when the user
            clicks Process without uploading anything.

    Returns:
        Tuple of (english_text, chinese_text, (sample_rate, audio_data))
        matching the Textbox/Textbox/Audio outputs.

    Raises:
        gr.Error: On missing input or any pipeline failure, so the message
            surfaces in the Gradio UI.
    """
    logger.info("Starting processing for uploaded audio")

    # Gradio passes None when no audio was uploaded or recorded.
    if audio_file is None:
        raise gr.Error("Please upload or record an audio clip first.")

    temp_path = None
    try:
        # Persist the in-memory audio so the file-based STT can consume it.
        sr, audio_data = audio_file
        temp_path = os.path.join("temp/uploads", f"upload_{time.time()}.wav")
        sf.write(temp_path, audio_data, sr)
        logger.info(f"Saved uploaded audio to {temp_path}")

        # STT Phase
        logger.info("Beginning STT processing")
        english_text = transcribe_audio(temp_path)
        logger.info(f"STT completed. Text length: {len(english_text)} characters")

        # Translation Phase
        logger.info("Beginning translation")
        chinese_text = translate_text(english_text)
        logger.info(f"Translation completed. Translated length: {len(chinese_text)} characters")

        # TTS Phase
        logger.info("Beginning TTS generation")
        engine = get_tts_engine(lang_code='z')  # 'z' selects Mandarin Chinese
        output_path = engine.generate_speech(chinese_text, voice="zf_xiaobei")
        logger.info(f"TTS completed. Output file: {output_path}")

        # Load the generated audio back as numpy for the Gradio Audio output.
        audio_data, sr = sf.read(output_path)
        return english_text, chinese_text, (sr, audio_data)

    except gr.Error:
        # Already user-facing; don't double-wrap the message.
        raise
    except Exception as e:
        logger.error(f"Processing failed: {str(e)}", exc_info=True)
        raise gr.Error(f"Processing Failed: {str(e)}")
    finally:
        # Previously the uploaded clip was left behind forever; clean it up
        # so temp/uploads does not grow without bound.
        if temp_path and os.path.exists(temp_path):
            try:
                os.remove(temp_path)
            except OSError:
                logger.warning(f"Could not remove temp file {temp_path}")
94
+
95
def stream_audio(chinese_text, voice, speed):
    """
    Stream synthesized Chinese speech in chunks for the Gradio interface.

    Args:
        chinese_text: The Chinese text to convert to speech.
        voice: The TTS voice ID to use (e.g. "zf_xiaobei").
        speed: The speech speed factor.

    Yields:
        (sample_rate, audio_chunk) tuples consumable by a gr.Audio output.
    """
    engine = get_tts_engine(lang_code='z')  # 'z' selects Mandarin Chinese

    # Yield each chunk directly. The previous implementation wrote every
    # chunk to a 16-bit WAV file and immediately read it back, which added
    # per-chunk disk I/O, quantized the audio, and leaked the temp file if
    # the read raised. Gradio accepts the numpy chunk as-is.
    for sample_rate, audio_chunk in engine.generate_speech_stream(
        chinese_text,
        voice=voice,
        speed=speed,
    ):
        yield (sample_rate, audio_chunk)
126
+
127
def create_interface():
    """
    Create and configure the Gradio interface.

    Returns:
        The assembled (but not yet launched) gr.Blocks interface.
    """
    with gr.Blocks(css=css) as interface:
        gr.Markdown("# 🎧 High-Quality Audio Translation System")
        gr.Markdown("Upload English Audio → Get Chinese Speech Output")

        with gr.Row():
            with gr.Column(scale=2):
                # File upload component
                audio_input = gr.Audio(
                    label="Upload English Audio",
                    type="numpy",
                    sources=["upload", "microphone"]
                )

                # Process button
                process_btn = gr.Button("Process Audio", variant="primary")

            with gr.Column(scale=1):
                # TTS Settings. gr.Box was removed in Gradio 4; gr.Group is
                # its replacement and is required with the gradio>=5 pin in
                # requirements.txt — gr.Box() raises AttributeError there.
                with gr.Group():
                    gr.Markdown("### TTS Settings")
                    voice_dropdown = gr.Dropdown(
                        choices=["Xiaobei (Female)", "Yunjian (Male)"],
                        value="Xiaobei (Female)",
                        label="Select Voice"
                    )
                    speed_slider = gr.Slider(
                        minimum=0.5,
                        maximum=2.0,
                        value=1.0,
                        step=0.1,
                        label="Speech Speed"
                    )

        # Output section
        with gr.Row():
            with gr.Column(scale=2):
                # Text outputs
                english_output = gr.Textbox(
                    label="Recognition Results",
                    lines=5,
                    elem_classes=["output-text"]
                )

                chinese_output = gr.Textbox(
                    label="Translation Results",
                    lines=5,
                    elem_classes=["output-text"]
                )

            with gr.Column(scale=1):
                # Audio output (gr.Audio provides its own download control)
                audio_output = gr.Audio(
                    label="Audio Output",
                    type="numpy"
                )

                # Stream button
                stream_btn = gr.Button("Stream Audio")

        # Set up event handlers
        process_btn.click(
            fn=handle_file_processing,
            inputs=[audio_input],
            outputs=[english_output, chinese_output, audio_output]
        )

        # Map the human-readable dropdown label to the engine voice ID.
        def get_voice_id(voice_name):
            voice_map = {
                "Xiaobei (Female)": "zf_xiaobei",
                "Yunjian (Male)": "zm_yunjian"
            }
            return voice_map.get(voice_name, "zf_xiaobei")

        # Use a named generator function rather than a lambda: Gradio only
        # streams output when inspect.isgeneratorfunction(fn) is true, and a
        # lambda that merely *returns* a generator object fails that check,
        # so the previous wiring never actually streamed.
        def stream_handler(text, voice, speed):
            yield from stream_audio(text, get_voice_id(voice), speed)

        stream_btn.click(
            fn=stream_handler,
            inputs=[chinese_output, voice_dropdown, speed_slider],
            outputs=audio_output
        )

        # Examples — NOTE(review): assumes examples/sample1.mp3 and
        # examples/sample2.wav exist in the repo; verify before deploying.
        gr.Examples(
            examples=[
                ["examples/sample1.mp3"],
                ["examples/sample2.wav"]
            ],
            inputs=audio_input
        )

    return interface
227
+
228
def main():
    """Application entry point: build the Gradio UI and serve it."""
    logger.info("Starting Gradio application")
    create_interface().launch()
235
+
236
+ if __name__ == "__main__":
237
+ main()
requirements.txt CHANGED
@@ -8,8 +8,11 @@ torchaudio>=2.1.0
8
  scipy>=1.11
9
  munch>=2.5
10
  accelerate>=1.2.0
11
- soundfile>=0.13.0
12
  kokoro>=0.9.4
13
  ordered-set>=4.1.0
14
  phonemizer-fork>=3.3.2
15
- descript-audio-codec
 
 
 
 
8
  scipy>=1.11
9
  munch>=2.5
10
  accelerate>=1.2.0
11
+ soundfile>=0.13.1
12
  kokoro>=0.9.4
13
  ordered-set>=4.1.0
14
  phonemizer-fork>=3.3.2
15
+ descript-audio-codec
16
+ gradio>=5.25.2
17
+ gradio-dialogue>=0.0.4
18
+ huggingface-hub>=0.30.2
utils/tts_dia.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
6
  import soundfile as sf
7
  from pathlib import Path
8
  from typing import Optional
 
9
 
10
  from dia.model import Dia
11
 
@@ -64,7 +65,7 @@ def _get_model() -> Dia:
64
  raise
65
  return _model
66
 
67
-
68
  def generate_speech(text: str, language: str = "zh") -> str:
69
  """Public interface for TTS generation using Dia model
70
 
 
6
  import soundfile as sf
7
  from pathlib import Path
8
  from typing import Optional
9
+ import spaces
10
 
11
  from dia.model import Dia
12
 
 
65
  raise
66
  return _model
67
 
68
+ @spaces.GPU
69
  def generate_speech(text: str, language: str = "zh") -> str:
70
  """Public interface for TTS generation using Dia model
71