RSHVR committed on
Commit
839f7b2
·
verified ·
1 Parent(s): 12d303c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +79 -173
app.py CHANGED
@@ -1,191 +1,97 @@
1
  import os
2
- import tempfile
3
  import gradio as gr
4
- import torch
5
- import torchaudio
6
- import spaces
7
- from huggingface_hub import snapshot_download
8
- from tortoise.api import TextToSpeech
9
- from tortoise.utils.audio import load_audio
10
- import numpy as np
11
- import uuid
12
- from pydub import AudioSegment
13
 
14
# Generated WAV files are written under ./outputs; make sure it exists.
os.makedirs("outputs", exist_ok=True)

# Import-time device probe (the original note says this reports CPU under
# Zero-GPU until a @spaces.GPU function is running).
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Initial device check: {device}")

# Probe tensor: its device is printed here and again inside the GPU functions.
zero = torch.Tensor([0])
zero = zero.cuda() if torch.cuda.is_available() else zero
print(f"Zero tensor device: {zero.device}")

# Tortoise model handle; None until the first generation request builds it.
tts = None

# Voice names offered in the preset dropdown.
PRESET_VOICES = [
    "random",
    "angie", "daniel", "deniro", "emma", "freeman", "geralt", "halle",
    "jlaw", "lj", "mol", "myself", "pat", "snakes", "tim_reynolds", "tom",
    "train_atkins", "train_daws", "train_dotrice", "train_dreams",
    "train_empire", "train_grace", "train_kennard", "train_lescault",
    "train_mouse",
    "weaver", "william",
]
36
-
37
def process_audio_file(audio_file_path):
    """Return the path of a 22.05 kHz WAV rendition of ``audio_file_path``.

    Files whose name does not end in ``.wav`` are decoded with pydub and
    exported to a temporary WAV. Audio whose sample rate is not 22050 Hz is
    resampled with torchaudio into a (second) temporary WAV. If neither step
    is needed the input path is returned unchanged.

    Args:
        audio_file_path: Path to the uploaded or recorded audio file.

    Returns:
        Path to a WAV file at 22050 Hz. When it is a temporary file it was
        created with ``delete=False``, so the caller owns its cleanup.
    """
    target_rate = 22050
    # Path of a temp WAV created here (never the caller's own file), so it can
    # be removed once it has been superseded by the resampled copy.
    converted_path = None

    if not audio_file_path.lower().endswith(".wav"):
        # Decode only when a conversion is actually required.
        audio = AudioSegment.from_file(audio_file_path)
        # Reserve a name and close the handle immediately; the export below
        # reopens the file by name, and an open handle would otherwise leak.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav:
            converted_path = temp_wav.name
        audio.export(converted_path, format="wav")
        audio_file_path = converted_path

    waveform, sample_rate = torchaudio.load(audio_file_path)
    if sample_rate != target_rate:
        resampler = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=target_rate
        )
        waveform = resampler(waveform)
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            resampled_path = temp_file.name
        torchaudio.save(resampled_path, waveform, target_rate)
        if converted_path is not None:
            # The intermediate WAV is no longer referenced by anyone.
            os.remove(converted_path)
        audio_file_path = resampled_path

    return audio_file_path
58
-
59
@spaces.GPU
def generate_tts_with_voice(text, voice_sample_path=None, preset_voice=None):
    """Synthesize ``text`` with Tortoise and write the result under ``outputs/``.

    Voice selection, in priority order:
      1. ``voice_sample_path`` - clone the voice from that audio file;
      2. ``preset_voice`` (other than ``"random"``) - a named Tortoise voice;
      3. otherwise - no conditioning is supplied (random voice).

    Args:
        text: Text to speak.
        voice_sample_path: Optional path to a reference recording.
        preset_voice: Optional voice name from ``PRESET_VOICES``.

    Returns:
        ``(output_path, message)`` on success, ``(None, "Error: ...")`` on any
        failure. Exceptions are converted to a message because the result is
        shown directly in the UI status box.
    """
    global tts

    if not text or not text.strip():
        return None, "Error: No text provided."

    try:
        # Inside a @spaces.GPU call, so CUDA is expected to be visible here.
        print(f"GPU function device: {zero.device}")

        # Build the model lazily on first use.
        if tts is None:
            tts = TextToSpeech(use_deepspeed=torch.cuda.is_available())
            print("TTS model initialized")

        voice_samples = None
        conditioning_latents = None

        if voice_sample_path:
            # load_audio returns a single waveform tensor (not a pair), so it
            # must not be tuple-unpacked.
            clip = load_audio(process_audio_file(voice_sample_path), 22050)
            voice_samples = [clip]
        elif preset_voice and preset_voice != "random":
            # Named voices are resolved to reference clips / latents here.
            from tortoise.utils.audio import load_voice

            voice_samples, conditioning_latents = load_voice(preset_voice)
        # else: both stay None -> Tortoise picks random conditioning.

        output_path = f"outputs/tts_output_{uuid.uuid4().hex[:8]}.wav"

        # NOTE(review): per the Tortoise API, `preset` selects a quality/speed
        # profile ("ultra_fast", "fast", "standard", "high_quality"), not a
        # voice; the voice is supplied via samples/latents. "fast" is the
        # library default.
        gen = tts.tts_with_preset(
            text,
            voice_samples=voice_samples,
            conditioning_latents=conditioning_latents,
            preset="fast",
        )

        # Save the generated audio at the 24 kHz rate used by the original code.
        torchaudio.save(output_path, gen.squeeze(0).cpu(), 24000)

        return output_path, "Success: TTS generation completed."
    except Exception as e:
        return None, f"Error: {str(e)}"
103
-
104
@spaces.GPU
def tts_interface(text, audio_file, preset_voice, record_audio):
    """Gradio click handler: choose the voice source and run generation.

    Args:
        text: Text to speak.
        audio_file: Path of an uploaded voice sample, or None.
        preset_voice: Selected preset name ("" when nothing is selected).
        record_audio: Microphone capture, or None. Gradio's numpy audio format
            is a ``(sample_rate, samples)`` tuple; a file path string is also
            accepted.

    Returns:
        ``(output_path_or_None, status_message)`` for the audio and status
        widgets.
    """
    print(f"Processing with device: {zero.device}")

    voice_sample_path = None

    # A recording takes precedence over an uploaded file.
    if isinstance(record_audio, str):
        voice_sample_path = record_audio
    elif record_audio is not None:
        # Element 0 is the sample rate and element 1 the samples; keep the
        # true rate so the file is not mislabeled (resampling happens later).
        sample_rate, samples = record_audio
        waveform = torch.as_tensor(samples)
        if waveform.dim() == 1:
            waveform = waveform.unsqueeze(0)
        else:
            # assumes multi-channel input is (samples, channels) as Gradio
            # documents; torchaudio wants (channels, samples) - TODO confirm
            waveform = waveform.t().contiguous()
        # Reserve a name and release the handle before torchaudio writes to it.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
            voice_sample_path = temp_file.name
        torchaudio.save(voice_sample_path, waveform, int(sample_rate))
    elif audio_file is not None:
        voice_sample_path = audio_file

    # With no custom voice and no preset chosen, fall back to a random voice.
    if voice_sample_path is None and not preset_voice:
        preset_voice = "random"

    output_path, message = generate_tts_with_voice(text, voice_sample_path, preset_voice)
    return (output_path or None), message
134
 
135
# Gradio UI: text + voice inputs on top, generate button, then audio/status.
with gr.Blocks(title="Tortoise TTS with Voice Cloning") as demo:
    gr.Markdown("# Tortoise Text-to-Speech with Voice Cloning")
    gr.Markdown("Enter text and either upload a voice sample, record your voice, or select a preset voice.")

    with gr.Row():
        with gr.Column():
            text_input = gr.Textbox(
                label="Text to speak",
                placeholder="Enter the text you want to convert to speech...",
                lines=5,
            )
            # The empty first choice means "no preset selected".
            preset_voice = gr.Dropdown(
                choices=[""] + PRESET_VOICES,
                label="Preset Voice (optional)",
                value="",
            )

        with gr.Column():
            gr.Markdown("### Voice Input Options")
            with gr.Tab("Upload Voice"):
                audio_file = gr.Audio(
                    label="Upload Voice Sample (optional)",
                    type="filepath",
                )
            with gr.Tab("Record Voice"):
                # Gradio 4 replaced the `source=` keyword with `sources=[...]`;
                # the type is spelled out because the handler consumes the
                # (sample_rate, samples) tuple form.
                record_audio = gr.Audio(
                    label="Record Your Voice (optional)",
                    sources=["microphone"],
                    type="numpy",
                )

    generate_button = gr.Button("Generate Speech")

    with gr.Row():
        output_audio = gr.Audio(label="Generated Speech")
        output_message = gr.Textbox(label="Status")

    # Input order must match tts_interface(text, audio_file, preset_voice, record_audio).
    generate_button.click(
        fn=tts_interface,
        inputs=[text_input, audio_file, preset_voice, record_audio],
        outputs=[output_audio, output_message],
    )

    gr.Markdown("### About This App")
    gr.Markdown("""
    This app uses Tortoise-TTS to generate high-quality speech from text.

    You can:
    - Enter any text you want to be spoken
    - Upload or record a voice sample for voice cloning
    - Or select from pre-defined voice presets

    The app runs on Hugging Face Spaces with Zero-GPU optimization.
    """)

if __name__ == "__main__":
    demo.launch()
 
 
 
 
 
1
  import os
 
2
  import gradio as gr
3
+ from fastrtc import Stream, ReplyOnPause, AdditionalOutputs
 
 
 
 
 
 
 
 
4
 
5
+ # Import your modules
6
+ import stt
7
+ import tts
8
+ import cohereAPI
9
 
10
# Cohere credential, read from the environment (None when the variable is unset).
COHERE_API_KEY = os.environ.get("COHERE_API_KEY")

# System prompt passed along with every message sent to the chat model.
system_message = "You respond concisely, in about 15 words or less"

# Module-wide chat transcript; starts empty and is replaced after each reply.
conversation_history = list()
 
 
 
16
 
17
async def response(audio_file_path):
    """Handle one spoken turn: transcribe it, query the LLM, speak the reply.

    Registered as the ReplyOnPause handler. Yields, in order:
      1. ``AdditionalOutputs(transcript, "user")``
      2. ``AdditionalOutputs(reply_text, "assistant")``
      3. ``(sample_rate, speech_array)`` - the synthesized reply audio

    If transcription produces no text, nothing is yielded and neither the LLM
    nor the TTS is called.

    Args:
        audio_file_path: The captured audio, forwarded untouched to
            ``stt.transcribe_audio``.
            NOTE(review): FastRTC documents ReplyOnPause handlers as receiving
            a ``(sample_rate, ndarray)`` tuple rather than a path - confirm
            that ``stt.transcribe_audio`` accepts what is actually passed.
    """
    # The history is a single module-level list, shared by every call.
    global conversation_history

    # Convert speech to text
    user_message = await stt.transcribe_audio(audio_file_path)

    # Silence/noise can trigger the pause detector; don't spend an LLM call
    # (or pollute the history) on an empty transcript.
    if not user_message or not str(user_message).strip():
        return

    # Transcript and role are passed as two positional values so they line up
    # with the two declared additional outputs (transcript, role).
    yield AdditionalOutputs(user_message, "user")

    # Send text to Cohere API
    response_text, updated_history = await cohereAPI.send_message(
        system_message,
        user_message,
        conversation_history,
        COHERE_API_KEY,
    )

    # Keep the history returned by the API wrapper for the next turn.
    conversation_history = updated_history

    # Generate speech from text (the first element of the result is unused here).
    _, (sample_rate, speech_array) = await tts.generate_speech(
        response_text,
        voice_preset="random",
    )

    yield AdditionalOutputs(response_text, "assistant")

    # Finally stream the audio reply back to the caller.
    yield (sample_rate, speech_array)
 
 
48
 
49
# Bidirectional audio stream: `response` is wrapped in ReplyOnPause and
# installed as the handler, and two text side outputs ("transcript", "role")
# are declared next to the audio.
# NOTE(review): FastRTC's documentation describes `additional_outputs` as a
# list of Gradio components (paired with an `additional_outputs_handler`);
# confirm that these dict specs are accepted by the installed version.
stream = Stream(
    modality="audio",
    mode="send-receive",
    handler=ReplyOnPause(response),
    additional_outputs=[
        {"name": "transcript", "type": "text"},
        {"name": "role", "type": "text"},
    ],
)
59
+
60
# Gradio page hosting the FastRTC stream plus a chat transcript.
with gr.Blocks(title="Voice Chat Assistant with ReplyOnPause") as demo:
    gr.Markdown("# Voice Chat Assistant")
    gr.Markdown("Speak and pause to trigger a response.")

    chatbot = gr.Chatbot(label="Conversation")

    # Mount the FastRTC UI
    # NOTE(review): FastRTC documents `Stream.ui` as a ready-made Blocks
    # attribute; confirm that calling it, and the `output_components` used
    # below, exist in the installed version.
    stream_ui = stream.ui(label="Speak")

    def update_chat(transcript, role, history):
        """Fold one transcript event into the chatbot history.

        Args:
            transcript: Text of the utterance; falsy values are ignored.
            role: "user" or "assistant"; any other value leaves the history
                unchanged.
            history: Current list of ``(user_text, assistant_text)`` pairs,
                or None/empty when the conversation has not started.

        Returns:
            A new history list; the list passed in is never mutated.
        """
        # Copy so the component's current value is not edited in place, and
        # tolerate a None history instead of raising on .append().
        updated = list(history) if history else []
        if not (transcript and role):
            return updated

        if role == "user":
            # Open a new turn that is still waiting for the assistant's reply.
            updated.append((transcript, None))
        elif role == "assistant":
            if updated and updated[-1][1] is None:
                # Complete the pending user turn.
                updated[-1] = (updated[-1][0], transcript)
            else:
                # Assistant spoke without a pending user turn.
                updated.append((None, transcript))
        return updated

    # Refresh the chatbot whenever the stream's side outputs change.
    stream_ui.change(
        update_chat,
        inputs=[stream_ui.output_components[0], stream_ui.output_components[1], chatbot],
        outputs=[chatbot],
    )

    # Clears the visible chatbot only.
    clear_btn = gr.Button("Clear Conversation")
    clear_btn.click(lambda: [], outputs=[chatbot])
 
 
 
 
 
 
 
 
 
90
 
91
# Script entry point: enable the request queue, then serve on all interfaces
# without a public share link and with errors shown in the UI.
if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "share": False,
        "show_error": True,
    }
    demo.queue().launch(**launch_options)