File size: 17,002 Bytes
6d75162
 
 
 
 
 
 
 
0f1518a
6d75162
 
 
 
633ab26
e02c9de
a6403d5
 
e02c9de
6d75162
e02c9de
6d75162
e02c9de
6d75162
 
 
e02c9de
6d75162
e02c9de
6d75162
e02c9de
6d75162
e02c9de
6d75162
 
e02c9de
6d75162
a6403d5
6d75162
e02c9de
633ab26
a6403d5
 
 
e02c9de
a6403d5
 
e02c9de
6d75162
 
e02c9de
633ab26
a6403d5
 
 
 
 
 
e02c9de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6d75162
 
e02c9de
6d75162
 
e02c9de
6d75162
633ab26
 
6d75162
 
 
e02c9de
633ab26
6d75162
 
 
955241f
 
 
 
 
 
6d75162
 
 
 
 
e02c9de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
633ab26
 
e02c9de
 
 
 
633ab26
 
e02c9de
 
 
0f1518a
e02c9de
 
 
 
 
 
 
 
 
 
 
6d75162
e02c9de
633ab26
955241f
 
 
 
 
 
 
 
e02c9de
 
 
 
 
 
 
 
 
 
 
633ab26
 
e02c9de
 
 
 
633ab26
 
 
 
 
 
 
e02c9de
 
 
633ab26
e02c9de
 
 
 
 
 
 
 
 
 
 
955241f
e02c9de
6d75162
 
 
e02c9de
6d75162
e02c9de
6d75162
 
 
e02c9de
 
6d75162
 
 
e02c9de
6d75162
 
 
 
 
 
 
 
e02c9de
6d75162
 
 
 
 
633ab26
 
 
 
 
 
 
 
 
 
 
 
 
 
6d75162
e02c9de
6d75162
 
633ab26
6d75162
e02c9de
 
6d75162
 
 
e02c9de
 
 
6d75162
e02c9de
 
 
6d75162
 
e02c9de
 
6d75162
 
 
e02c9de
6d75162
 
 
 
 
 
e02c9de
6d75162
 
 
 
 
633ab26
 
 
 
 
 
 
 
 
 
 
 
 
 
6d75162
e02c9de
6d75162
 
633ab26
6d75162
e02c9de
 
 
6d75162
e02c9de
6d75162
e02c9de
 
 
6d75162
e02c9de
6d75162
 
 
e02c9de
 
6d75162
 
 
 
 
 
 
e02c9de
 
6d75162
e02c9de
 
6d75162
 
 
e02c9de
 
 
6d75162
e02c9de
 
 
 
 
6d75162
e02c9de
 
 
 
 
 
6d75162
e02c9de
 
 
 
a6403d5
e02c9de
6d75162
 
633ab26
6d75162
 
 
 
e02c9de
6d75162
e02c9de
6d75162
e02c9de
 
6d75162
 
a6403d5
e02c9de
 
 
a6403d5
633ab26
a6403d5
 
 
e02c9de
a6403d5
 
e02c9de
a6403d5
 
6d75162
e02c9de
6d75162
955241f
6d75162
 
 
 
633ab26
 
6d75162
 
 
 
 
 
 
 
 
955241f
 
 
 
 
 
6d75162
633ab26
 
6d75162
 
 
 
 
 
e02c9de
6d75162
 
ecd7f95
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
import os
import tempfile
import time
from typing import List, Tuple

import gradio as gr
import torch
import torchaudio
# import spaces
from dataclasses import dataclass
from generator import Segment, load_csm_1b
from huggingface_hub import login


# Disable torch compile feature to avoid triton error
torch._dynamo.config.suppress_errors = True

# Check if GPU is available and configure the device
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Authenticate with the Hugging Face Hub when a token is configured.
def login_huggingface():
    """Log in to the Hugging Face Hub using the ``HF_TOKEN`` environment variable.

    When the variable is missing or empty, a notice is printed and no login
    is attempted, so gated models may be unreachable later.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        print("HF_TOKEN not found in environment variables. Some models may not be accessible.")
        return
    print("Logging in to Hugging Face Hub...")
    login(token=token)
    print("Login successful!")

# Authenticate against the Hugging Face Hub once, when the module is imported.
login_huggingface()

# Lazily-populated model cache shared by the request handlers:
# `generator` receives the object returned by load_csm_1b() and `model_loaded`
# becomes True once initialize_model() has populated it.
generator, model_loaded = None, False

# Load the model into the module-level cache (wrap with @spaces.GPU on ZeroGPU).
# @spaces.GPU(duration=30)
def initialize_model():
    """Load the CSM-1B generator into the module-level cache on first use.

    The model is placed on the module-level ``device`` ("cuda" when a GPU is
    visible, otherwise "cpu"). Previously "cuda" was hard-coded, which made
    loading fail on CPU-only hosts even though ``device`` had fallen back.

    Returns:
        The cached generator returned by ``load_csm_1b``.
    """
    global generator, model_loaded
    if not model_loaded:
        print(f"Loading CSM-1B model on {device}...")
        generator = load_csm_1b(device=device)
        # Set the flag only after a successful load, so a failed attempt is
        # retried on the next call.
        model_loaded = True
        print("Model loaded successfully!")
    return generator

# Return the cached model, loading it on first use.
# @spaces.GPU(duration=30)
def get_model():
    """Return the cached CSM-1B generator, initializing it if it is not loaded yet."""
    if model_loaded:
        return generator
    return initialize_model()

# Interpret the PRELOAD_MODEL environment variable at startup.
def preload_model_if_needed():
    """React to the ``PRELOAD_MODEL`` environment variable.

    Despite the name, this does not load the model. When ``PRELOAD_MODEL`` is
    "true", "1" or "yes" (case-insensitive), it only resets ``model_loaded``
    to False so the model is loaded by the first request. The stated reason
    is to avoid calling a ``@spaces.GPU``-decorated function at import time.
    For any other value it just reports that loading happens on demand.
    """
    global model_loaded
    requested = os.environ.get("PRELOAD_MODEL", "").lower() in ("true", "1", "yes")
    if not requested:
        print("PRELOAD_MODEL is not set. Model will be loaded on demand.")
        return
    print("PRELOAD_MODEL is set. Attempting to preload model...")
    try:
        model_loaded = False
        print("Model will be loaded on first request.")
    except Exception as e:
        print(f"Error during model preloading setup: {e}")


# Evaluate the preload setting once when the module is imported.
preload_model_if_needed()

# Load an audio file as a single-channel waveform.
def audio_to_tensor(audio_path: str) -> Tuple[torch.Tensor, int]:
    """Read ``audio_path`` and down-mix it to one channel.

    Returns:
        ``(waveform, sample_rate)``. ``waveform`` is the mean over dim 0
        (the channel axis) of the tensor from ``torchaudio.load``, and
        ``sample_rate`` is the file's native rate.
    """
    signal, rate = torchaudio.load(audio_path)
    mono = torch.mean(signal, dim=0)
    return mono, rate

# Write a generated waveform to a uniquely named WAV file.
def save_audio(audio_tensor: torch.Tensor, sample_rate: int, output_dir: str = ".") -> str:
    """Save ``audio_tensor`` as a WAV file and return its path.

    The file is created in ``output_dir`` (default: the current working
    directory, where Gradio can serve it). The name is unique, so concurrent
    or same-second requests no longer overwrite each other's output.

    Args:
        audio_tensor: Waveform as a 1-D ``(frames,)`` tensor or a 2-D
            ``(channels, frames)`` tensor, on any device.
        sample_rate: Sample rate in Hz to record in the file.
        output_dir: Directory in which to create the file.

    Returns:
        Path of the written WAV file.
    """
    # mkstemp atomically reserves a unique filename. The timestamp prefix is
    # kept so outputs still sort chronologically.
    fd, output_path = tempfile.mkstemp(
        prefix=f"csm1b_output_{int(time.time())}_", suffix=".wav", dir=output_dir
    )
    os.close(fd)  # torchaudio opens the path itself
    # torchaudio.save needs a CPU tensor shaped (channels, frames).
    waveform = audio_tensor.detach().cpu()
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    torchaudio.save(output_path, waveform, sample_rate)
    return output_path

# Build one context Segment from an uploaded clip and its transcript.
def _context_segment(audio_path: str, text: str, speaker: int, target_sample_rate: int) -> "Segment":
    """Load ``audio_path`` as mono, resample it to ``target_sample_rate`` if needed, and wrap it in a Segment."""
    waveform, sample_rate = audio_to_tensor(audio_path)
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_sample_rate
        )
    return Segment(speaker=speaker, text=text, audio=waveform)


# Generate speech conditioned on up to two context clips.
# @spaces.GPU(duration=30)
def generate_speech(
    text: str,
    speaker_id: int,
    context_audio_path1: str = None,
    context_text1: str = None,
    context_speaker1: int = 0,
    context_audio_path2: str = None,
    context_text2: str = None,
    context_speaker2: int = 1,
    max_duration_ms: float = 30000,
    temperature: float = 0.9,
    top_k: int = 50,
    progress=gr.Progress()
) -> str:
    """Generate speech for ``text``, optionally conditioned on context clips.

    A context pair (audio path + transcript) is used only when both parts are
    provided. Its audio is resampled to the generator's sample rate.

    Args:
        text: Text to synthesize.
        speaker_id: Speaker ID for the generated utterance.
        context_audio_path1, context_text1, context_speaker1: First context clip.
        context_audio_path2, context_text2, context_speaker2: Second context clip.
        max_duration_ms: Upper bound on generated audio length, in milliseconds.
        temperature, top_k: Currently unused; the matching
            ``generator.generate`` arguments are commented out below.
        progress: Gradio progress tracker.

    Returns:
        Path to the generated WAV file.

    Raises:
        gr.Error: On empty text or any generation failure. Gradio shows the
            message in the UI. Previously an error string was returned into
            an Audio output that expects a file path.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to convert to speech.")
    try:
        generator = get_model()

        progress(0.1, "Processing context...")
        context = [
            _context_segment(path, ctx_text, ctx_speaker, generator.sample_rate)
            for path, ctx_text, ctx_speaker in (
                (context_audio_path1, context_text1, context_speaker1),
                (context_audio_path2, context_text2, context_speaker2),
            )
            if path and ctx_text
        ]

        progress(0.3, "Generating audio...")
        audio = generator.generate(
            text=text,
            speaker=speaker_id,
            context=context,
            max_audio_length_ms=max_duration_ms,
            # temperature=temperature,
            # topk=top_k
        )

        progress(0.8, "Saving audio...")
        # Previously the audio was never written, so the returned path did not
        # exist. Move the audio to CPU first: torchaudio.save needs a CPU tensor.
        output_path = save_audio(audio.cpu(), generator.sample_rate)
        print(f"Audio saved to {output_path}")

        progress(1.0, "Completed!")
        return output_path
    except Exception as e:
        error_message = str(e)
        # ZeroGPU quota errors carry a retry delay; surface it to the user.
        if "GPU quota exceeded" in error_message:
            import re
            wait_time_match = re.search(r"Try again in (\d+:\d+:\d+)", error_message)
            wait_time = wait_time_match.group(1) if wait_time_match else "some time"
            raise gr.Error(f"GPU quota exceeded. Please try again in {wait_time}.") from e
        raise gr.Error(f"Error generating speech: {error_message}") from e

# Generate speech from text alone, without context clips.
# @spaces.GPU(duration=30)
def generate_speech_simple(
    text: str,
    speaker_id: int,
    max_duration_ms: float = 30000,
    temperature: float = 0.9,
    top_k: int = 50,
    progress=gr.Progress()
) -> str:
    """Generate speech for ``text`` with no conversational context.

    Args:
        text: Text to synthesize.
        speaker_id: Speaker ID for the generated utterance.
        max_duration_ms: Upper bound on generated audio length, in milliseconds.
        temperature, top_k: Currently unused; the matching
            ``generator.generate`` arguments are commented out below.
        progress: Gradio progress tracker.

    Returns:
        Path to the generated WAV file.

    Raises:
        gr.Error: On empty text or any generation failure. Gradio shows the
            message in the UI. Previously an error string was returned into
            an Audio output that expects a file path.
    """
    if not text or not text.strip():
        raise gr.Error("Please enter some text to convert to speech.")
    try:
        generator = get_model()

        progress(0.3, "Generating audio...")
        audio = generator.generate(
            text=text,
            speaker=speaker_id,
            context=[],  # No context
            max_audio_length_ms=max_duration_ms,
            # temperature=temperature,
            # topk=top_k
        )

        progress(0.8, "Saving audio...")
        # torchaudio.save needs a CPU tensor. save_audio chooses the output
        # filename, instead of a per-second name that concurrent requests
        # could overwrite.
        output_path = save_audio(audio.cpu(), generator.sample_rate)
        print(f"Audio saved to {output_path}")

        progress(1.0, "Completed!")
        return output_path
    except Exception as e:
        error_message = str(e)
        # ZeroGPU quota errors carry a retry delay; surface it to the user.
        if "GPU quota exceeded" in error_message:
            import re
            wait_time_match = re.search(r"Try again in (\d+:\d+:\d+)", error_message)
            wait_time = wait_time_match.group(1) if wait_time_match else "some time"
            raise gr.Error(f"GPU quota exceeded. Please try again in {wait_time}.") from e
        raise gr.Error(f"Error generating speech: {error_message}") from e

# Build the Gradio user interface.
def create_demo():
    """Build and return the Gradio Blocks app for CSM-1B text-to-speech.

    Tabs:
        * Simple Audio Generation: wired to ``generate_speech_simple``.
        * Audio Generation with Context: wired to ``generate_speech``.
        * Configuration: set a Hugging Face token at runtime.
        * GPU Information: inspect the GPU and trigger model loading.

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="CSM-1B Text-to-Speech") as demo:
        gr.Markdown("# CSM-1B Text-to-Speech Demo")
        # CSM = Conversational Speech Model (it was mislabelled "Collaborative").
        gr.Markdown("CSM-1B (Conversational Speech Model) is an advanced text-to-speech model capable of generating natural-sounding speech from text.")
        
        with gr.Tab("Simple Audio Generation"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(
                        label="Text to convert to speech",
                        placeholder="Enter the text you want to convert to speech...",
                        lines=5
                    )
                    speaker_id = gr.Number(
                        label="Speaker ID",
                        value=0,
                        precision=0,
                        minimum=0,
                        maximum=10
                    )
                    
                    with gr.Row():
                        max_duration = gr.Slider(
                            label="Maximum Duration (ms)",
                            minimum=1000,
                            maximum=90000,
                            value=30000,
                            step=1000
                        )
                        # temperature = gr.Slider(
                        #     label="Temperature",
                        #     minimum=0.1,
                        #     maximum=1.5,
                        #     value=0.9,
                        #     step=0.1
                        # )
                        # top_k = gr.Slider(
                        #     label="Top-K",
                        #     minimum=1,
                        #     maximum=100,
                        #     value=50,
                        #     step=1
                        # )
                    
                    generate_btn = gr.Button("Generate Audio")
                
                with gr.Column():
                    output_audio = gr.Audio(label="Output Audio", type="filepath", autoplay=True)
        
        with gr.Tab("Audio Generation with Context"):
            gr.Markdown("This feature allows you to provide audio clips and text as context to help the model generate more appropriate speech.")
            
            with gr.Row():
                with gr.Column():
                    context_text1 = gr.Textbox(label="Context Text 1", lines=2)
                    context_audio1 = gr.Audio(label="Context Audio 1", type="filepath")
                    context_speaker1 = gr.Number(label="Speaker ID 1", value=0, precision=0)
                    
                    context_text2 = gr.Textbox(label="Context Text 2", lines=2)
                    context_audio2 = gr.Audio(label="Context Audio 2", type="filepath")
                    context_speaker2 = gr.Number(label="Speaker ID 2", value=1, precision=0)
                    
                    text_input_context = gr.Textbox(
                        label="Text to convert to speech",
                        placeholder="Enter the text you want to convert to speech...",
                        lines=3
                    )
                    speaker_id_context = gr.Number(
                        label="Speaker ID",
                        value=0,
                        precision=0
                    )
                    
                    with gr.Row():
                        max_duration_context = gr.Slider(
                            label="Maximum Duration (ms)",
                            minimum=1000,
                            maximum=90000,
                            value=30000,
                            step=1000
                        )
                        # temperature_context = gr.Slider(
                        #     label="Temperature",
                        #     minimum=0.1,
                        #     maximum=1.5,
                        #     value=0.9,
                        #     step=0.1
                        # )
                        # top_k_context = gr.Slider(
                        #     label="Top-K",
                        #     minimum=1,
                        #     maximum=100,
                        #     value=50,
                        #     step=1
                        # )
                    
                    generate_context_btn = gr.Button("Generate Audio with Context")
                
                with gr.Column():
                    output_audio_context = gr.Audio(label="Output Audio", type="filepath", autoplay=True)
        
        # Add Hugging Face configuration tab
        with gr.Tab("Configuration"):
            gr.Markdown("### Hugging Face Token Configuration")
            gr.Markdown("""
            To use the CSM-1B model, you need access to the model on Hugging Face.
            
            You can configure your token by:
            1. Create a token at [Hugging Face Settings](https://huggingface.co/settings/tokens)
            2. Set the `HF_TOKEN` environment variable with your token value
            
            Note: In Hugging Face Spaces, you can set environment variables in the Space Settings.
            """)
            
            # NOTE: os.environ and the Hub login are process-wide, so a token
            # entered here applies to every user of this server. The old
            # "(Only for this session)" label was misleading.
            hf_token_input = gr.Textbox(
                label="Hugging Face Token (applies to the whole server, not just your session)",
                placeholder="Enter your token...",
                type="password"
            )
            
            def set_token(token):
                """Log in to the Hub with ``token`` and report the outcome as a status string."""
                token = (token or "").strip()
                if not token:
                    return "Invalid token. Please enter a valid token."
                try:
                    login(token=token)
                except Exception as e:
                    # A rejected token used to crash the handler; report it instead.
                    return f"Login failed: {e}"
                # Publish the token only after the Hub has accepted it.
                os.environ["HF_TOKEN"] = token
                return "Token set successfully! You can now load the model."
            
            set_token_btn = gr.Button("Set Token")
            token_status = gr.Textbox(label="Status", interactive=False)
            
            set_token_btn.click(fn=set_token, inputs=hf_token_input, outputs=token_status)
            
        # Add GPU information tab
        with gr.Tab("GPU Information"):
            gr.Markdown("### About ZeroGPU")
            gr.Markdown("""
            This application uses Hugging Face Spaces' ZeroGPU to optimize GPU usage.
            
            ZeroGPU helps free up GPU memory when not in use, saving resources and improving performance.
            
            When you generate audio, the GPU will be used automatically and released after completion.
            
            Note: In the ZeroGPU environment, CUDA is not initialized in the main process, but only in functions with the @spaces.GPU decorator.
            """)
            
            gr.Markdown("### GPU Quota Information")
            gr.Markdown("""
            Hugging Face Spaces has GPU quota limitations:
            
            - Each GPU operation has a default duration of 60 seconds
            - We've reduced this to 30 seconds for audio generation and 10 seconds for GPU checks
            - If you exceed your quota, you'll need to wait for it to reset (usually a few hours)
            - For better performance, try generating shorter audio clips
            
            If you encounter a "GPU quota exceeded" error, please wait for the specified time and try again.
            """)
            
            # @spaces.GPU(duration=10)
            def check_gpu():
                """Describe the first visible CUDA device, or report CPU-only mode."""
                if torch.cuda.is_available():
                    gpu_name = torch.cuda.get_device_name(0)
                    gpu_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3)
                    return f"GPU: {gpu_name}\nMemory: {gpu_memory:.2f} GB"
                else:
                    return "No GPU found. The application will run on CPU."
            
            check_gpu_btn = gr.Button("Check GPU")
            gpu_info = gr.Textbox(label="GPU Information", interactive=False)
            
            check_gpu_btn.click(fn=check_gpu, inputs=None, outputs=gpu_info)
            
            # Add model loading button
            load_model_btn = gr.Button("Load Model")
            model_status = gr.Textbox(label="Model Status", interactive=False)
            
            # @spaces.GPU(duration=10)
            def load_model_and_report():
                """Load the model if needed and return a status message for the UI."""
                # Read-only access to the module-level flag; no `global` needed.
                if model_loaded:
                    return "Model has already been loaded!"
                try:
                    initialize_model()
                except Exception as e:
                    # Report load failures (missing access, OOM, ...) instead of crashing.
                    return f"Failed to load model: {e}"
                return "Model loaded successfully!"
            
            load_model_btn.click(fn=load_model_and_report, inputs=None, outputs=model_status)
        
        # Connect components
        generate_btn.click(
            fn=generate_speech_simple,
            inputs=[
                text_input,
                speaker_id,
                max_duration,
                # temperature,
                # top_k
            ],
            outputs=output_audio
        )
        
        generate_context_btn.click(
            fn=generate_speech,
            inputs=[
                text_input_context,
                speaker_id_context,
                context_audio1,
                context_text1,
                context_speaker1,
                context_audio2,
                context_text2,
                context_speaker2,
                max_duration_context,
                # temperature_context,
                # top_k_context
            ],
            outputs=output_audio_context
        )
    
    return demo

# Entry point: build the UI and serve it with request queuing enabled.
if __name__ == "__main__":
    demo = create_demo()
    demo.queue()
    demo.launch()