RSHVR commited on
Commit
1e82508
·
verified ·
1 Parent(s): fd1adc1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +254 -0
app.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tempfile
3
+ import gradio as gr
4
+ import torch
5
+ import torchaudio
6
+ import spaces
7
+ from fastapi import FastAPI, File, UploadFile, Form
8
+ from fastapi.responses import FileResponse
9
+ from tortoise.api import TextToSpeech
10
+ from tortoise.utils.audio import load_audio
11
+ import numpy as np
12
+ import uvicorn
13
+ from typing import Optional
14
+ import uuid
15
+ from pydub import AudioSegment
16
+
17
+ # Create output directory if it doesn't exist
18
+ os.makedirs("outputs", exist_ok=True)
19
+
20
+ # Check for CUDA availability (this will show CPU due to Zero-GPU)
21
+ device = "cuda" if torch.cuda.is_available() else "cpu"
22
+ print(f"Initial device check: {device}")
23
+
24
+ # Create a tensor to verify Zero-GPU is working
25
+ zero = torch.Tensor([0])
26
+ if torch.cuda.is_available():
27
+ zero = zero.cuda()
28
+ print(f"Zero tensor device: {zero.device}")
29
+
30
+ # Initialize FastAPI
31
+ app = FastAPI(title="Tortoise TTS API")
32
+
33
+ # Initialize TTS (will be loaded on demand with Zero-GPU)
34
+ tts = None
35
+
36
+ # Available preset voice options
37
+ PRESET_VOICES = ["random", "angie", "daniel", "deniro", "emma", "freeman",
38
+ "geralt", "halle", "jlaw", "lj", "mol", "myself", "pat",
39
+ "snakes", "tim_reynolds", "tom", "train_atkins", "train_daws",
40
+ "train_dotrice", "train_dreams", "train_empire", "train_grace",
41
+ "train_kennard", "train_lescault", "train_mouse", "weaver", "william"]
42
+
43
+ def process_audio_file(audio_file_path):
44
+ """Process uploaded audio file to ensure it meets Tortoise requirements"""
45
+ # Load audio file
46
+ audio = AudioSegment.from_file(audio_file_path)
47
+
48
+ # Convert to WAV format if it's not already
49
+ if not audio_file_path.lower().endswith('.wav'):
50
+ temp_wav = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
51
+ audio.export(temp_wav.name, format="wav")
52
+ audio_file_path = temp_wav.name
53
+
54
+ # Resample to 22.05kHz which is what Tortoise expects
55
+ y, sr = torchaudio.load(audio_file_path)
56
+ if sr != 22050:
57
+ resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=22050)
58
+ y = resampler(y)
59
+ temp_file = tempfile.NamedTemporaryFile(suffix='.wav', delete=False)
60
+ torchaudio.save(temp_file.name, y, 22050)
61
+ audio_file_path = temp_file.name
62
+
63
+ return audio_file_path
64
+
65
+ @spaces.GPU
66
+ def generate_tts_with_voice(text, voice_sample_path=None, preset_voice=None):
67
+ """Generate TTS audio using Tortoise with either a custom voice or preset"""
68
+ global tts
69
+
70
+ try:
71
+ # Now that we're inside the @spaces.GPU decorated function, CUDA should be available
72
+ print(f"GPU function device: {zero.device}")
73
+
74
+ # Initialize TTS model if not already initialized
75
+ if tts is None:
76
+ tts = TextToSpeech(use_deepspeed=True if torch.cuda.is_available() else False)
77
+ print("TTS model initialized")
78
+
79
+ voice_samples = None
80
+
81
+ if voice_sample_path:
82
+ # Process the voice sample
83
+ voice_sample_path = process_audio_file(voice_sample_path)
84
+ voice_samples, _ = load_audio(voice_sample_path, 22050)
85
+ voice_samples = [voice_samples]
86
+ preset_voice = None
87
+ elif preset_voice and preset_voice != "random":
88
+ voice_samples = None
89
+ else: # random voice
90
+ voice_samples = None
91
+ preset_voice = "random"
92
+
93
+ # Generate the speech
94
+ output_id = str(uuid.uuid4())[:8]
95
+ output_path = f"outputs/tts_output_{output_id}.wav"
96
+
97
+ gen = tts.tts_with_preset(
98
+ text,
99
+ voice_samples=voice_samples,
100
+ preset=preset_voice
101
+ )
102
+
103
+ # Save the generated audio
104
+ torchaudio.save(output_path, gen.squeeze(0).cpu(), 24000)
105
+
106
+ return output_path, "Success: TTS generation completed."
107
+ except Exception as e:
108
+ return None, f"Error: {str(e)}"
109
+
110
+ @spaces.GPU
111
+ def tts_interface(text, audio_file, preset_voice, record_audio):
112
+ """Interface function for Gradio with GPU acceleration"""
113
+ print(f"Processing with device: {zero.device}")
114
+
115
+ voice_sample_path = None
116
+
117
+ # Determine which voice input to use
118
+ if record_audio is not None:
119
+ # Use recorded audio
120
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.wav')
121
+ temp_file.close()
122
+ record_audio = (record_audio[0], 22050) # Ensure sample rate is 22050
123
+ torchaudio.save(temp_file.name, torch.tensor(record_audio[0]).unsqueeze(0), record_audio[1])
124
+ voice_sample_path = temp_file.name
125
+ elif audio_file is not None:
126
+ # Use uploaded audio file
127
+ voice_sample_path = audio_file
128
+
129
+ # If no custom voice is provided, use the preset
130
+ if voice_sample_path is None and preset_voice == "":
131
+ preset_voice = "random"
132
+
133
+ # Generate TTS
134
+ output_path, message = generate_tts_with_voice(text, voice_sample_path, preset_voice)
135
+
136
+ if output_path:
137
+ return output_path, message
138
+ else:
139
+ return None, message
140
+
141
+ # FastAPI endpoints
142
+ @app.post("/api/tts_with_voice_file/")
143
+ @spaces.GPU
144
+ async def tts_with_voice_file(
145
+ text: str = Form(...),
146
+ voice_file: Optional[UploadFile] = File(None),
147
+ preset_voice: Optional[str] = Form("random")
148
+ ):
149
+ """API endpoint for TTS with an uploaded voice file"""
150
+ try:
151
+ print(f"Processing with device: {zero.device}")
152
+
153
+ voice_sample_path = None
154
+ if voice_file:
155
+ # Save uploaded file temporarily
156
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(voice_file.filename)[1])
157
+ temp_file.write(await voice_file.read())
158
+ temp_file.close()
159
+ voice_sample_path = temp_file.name
160
+
161
+ output_path, message = generate_tts_with_voice(text, voice_sample_path, preset_voice)
162
+
163
+ if output_path:
164
+ return FileResponse(output_path, media_type="audio/wav", filename="tts_output.wav")
165
+ else:
166
+ return {"status": "error", "message": message}
167
+ except Exception as e:
168
+ return {"status": "error", "message": f"Failed to process: {str(e)}"}
169
+
170
+ @app.post("/api/tts_with_preset/")
171
+ @spaces.GPU
172
+ async def tts_with_preset(
173
+ text: str = Form(...),
174
+ preset_voice: str = Form("random")
175
+ ):
176
+ """API endpoint for TTS with a preset voice"""
177
+ try:
178
+ print(f"Processing with device: {zero.device}")
179
+
180
+ output_path, message = generate_tts_with_voice(text, preset_voice=preset_voice)
181
+
182
+ if output_path:
183
+ return FileResponse(output_path, media_type="audio/wav", filename="tts_output.wav")
184
+ else:
185
+ return {"status": "error", "message": message}
186
+ except Exception as e:
187
+ return {"status": "error", "message": f"Failed to process: {str(e)}"}
188
+
189
+ # Create Gradio interface
190
+ with gr.Blocks(title="Tortoise TTS with Voice Cloning") as demo:
191
+ gr.Markdown("# Tortoise Text-to-Speech with Voice Cloning")
192
+ gr.Markdown("Enter text and either upload a voice sample, record your voice, or select a preset voice.")
193
+
194
+ with gr.Row():
195
+ with gr.Column():
196
+ text_input = gr.Textbox(
197
+ label="Text to speak",
198
+ placeholder="Enter the text you want to convert to speech...",
199
+ lines=5
200
+ )
201
+ preset_voice = gr.Dropdown(
202
+ choices=[""] + PRESET_VOICES,
203
+ label="Preset Voice (optional)",
204
+ value=""
205
+ )
206
+
207
+ with gr.Column():
208
+ gr.Markdown("### Voice Input Options")
209
+ with gr.Tab("Upload Voice"):
210
+ audio_file = gr.Audio(
211
+ label="Upload Voice Sample (optional)",
212
+ type="filepath"
213
+ )
214
+ with gr.Tab("Record Voice"):
215
+ record_audio = gr.Audio(
216
+ label="Record Your Voice (optional)",
217
+ source="microphone"
218
+ )
219
+
220
+ generate_button = gr.Button("Generate Speech")
221
+
222
+ with gr.Row():
223
+ output_audio = gr.Audio(label="Generated Speech")
224
+ output_message = gr.Textbox(label="Status")
225
+
226
+ generate_button.click(
227
+ fn=tts_interface,
228
+ inputs=[text_input, audio_file, preset_voice, record_audio],
229
+ outputs=[output_audio, output_message]
230
+ )
231
+
232
+ gr.Markdown("### API Endpoints")
233
+ gr.Markdown("""
234
+ This app also provides API endpoints:
235
+
236
+ 1. **Voice File TTS** - `/api/tts_with_voice_file/`
237
+ - POST request with:
238
+ - `text`: Text to convert to speech (required)
239
+ - `voice_file`: Audio file for voice cloning (optional)
240
+ - `preset_voice`: Name of preset voice (optional, defaults to "random")
241
+
242
+ 2. **Preset Voice TTS** - `/api/tts_with_preset/`
243
+ - POST request with:
244
+ - `text`: Text to convert to speech (required)
245
+ - `preset_voice`: Name of preset voice (required)
246
+
247
+ Both endpoints return a WAV file with the generated speech.
248
+ """)
249
+
250
+ # Mount the Gradio app to FastAPI
251
+ app = gr.mount_gradio_app(app, demo, path="/")
252
+
253
+ if __name__ == "__main__":
254
+ uvicorn.run(app, host="0.0.0.0", port=7860)