import base64
import io
import logging
import os
import sys
from typing import List

import numpy as np
import gradio as gr

# Logging must be configured before the first log call: a module-level
# logging.warning() on an unconfigured root logger calls basicConfig() itself
# with the default WARNING level, after which basicConfig(level=INFO) is a
# no-op and every INFO message below would be hidden.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Optional heavy dependencies. Without them the app still runs, but only with
# the silent MockGenerator. torchaudio is not used directly in this file; it is
# imported so that HAS_TORCH is False when it is missing.
try:
    import torch
    import torchaudio  # noqa: F401
    HAS_TORCH = True
except ImportError:
    HAS_TORCH = False
    logger.warning("PyTorch not available. Using mock generator.")


class MockGenerator:
    """Stand-in generator that returns silence.

    Used when PyTorch is unavailable or the real model fails to load, so the UI
    keeps working. It exposes the same ``sample_rate`` attribute and
    ``generate`` keyword arguments that this file uses on the real generator.
    """

    def __init__(self):
        self.sample_rate = 24000
        logger.info("Created mock generator with sample rate 24000")

    def generate(self, text, speaker, context=None, max_audio_length_ms=10000,
                 temperature=0.9, topk=50):
        """Return silent float32 audio whose length scales with ``len(text)``.

        The duration is 0.1 s per character, capped at ``max_audio_length_ms``.
        ``speaker``, ``context``, ``temperature`` and ``topk`` are accepted only
        for signature compatibility and are ignored.
        """
        duration_seconds = min(len(text) * 0.1, max_audio_length_ms / 1000)
        samples = int(duration_seconds * self.sample_rate)
        logger.info(f"Generating mock audio with {samples} samples")
        return np.zeros(samples, dtype=np.float32)


class MockSegment:
    """Minimal context segment used when the real ``generator.Segment`` can't be imported."""

    def __init__(self, text, speaker, audio=None):
        self.text = text
        self.speaker = speaker
        self.audio = audio if audio is not None else np.zeros(0, dtype=np.float32)


# Module-wide generator handle, set by initialize_model(): either the object
# returned by load_csm_1b() or a MockGenerator.
generator = None


def initialize_model():
    """Load the CSM 1B generator into the module-level ``generator``.

    Falls back to ``MockGenerator`` when PyTorch is unavailable or when
    importing/loading the real model raises.

    Returns:
        Always ``True``; failures are handled by installing the mock.
    """
    global generator
    logger.info("Loading CSM 1B model...")

    if not HAS_TORCH:
        logger.warning("PyTorch not available. Using mock generator.")
        generator = MockGenerator()
        return True

    try:
        # Make a `generator` module in the current working directory importable.
        cwd = os.getcwd()
        if cwd not in sys.path:
            sys.path.append(cwd)

        from generator import load_csm_1b

        device = "cuda" if torch.cuda.is_available() else "cpu"
        if device == "cpu":
            logger.warning("GPU not available. Using CPU, performance may be slow!")
        logger.info(f"Using device: {device}")

        generator = load_csm_1b(device=device)
        logger.info(f"Model loaded successfully on device: {device}")
    except Exception as e:
        # Best-effort: keep the app usable with silent audio instead of crashing.
        logger.exception(f"Error loading model: {str(e)}")
        logger.warning("Falling back to mock generator")
        generator = MockGenerator()
    return True


def _resolve_segment_class():
    """Return ``generator.Segment`` if importable, otherwise ``MockSegment``."""
    try:
        from generator import Segment
    except ImportError:
        return MockSegment
    return Segment


def _build_context_segments(segment_cls, context_texts, context_speakers):
    """Build segment objects from parallel lists of context texts and speaker IDs.

    Pairs with an empty text or a ``None`` speaker are skipped, and extra items
    in the longer list are ignored (``zip`` semantics).
    """
    if not context_texts or not context_speakers:
        return []
    # NOTE(review): every segment gets an empty audio buffer because the UI only
    # collects context *text*; confirm the real generator accepts that.
    segments = []
    for ctx_text, ctx_speaker in zip(context_texts, context_speakers):
        if not ctx_text or ctx_speaker is None:
            continue
        if HAS_TORCH:
            empty_audio = torch.zeros(0, dtype=torch.float32)
        else:
            empty_audio = np.zeros(0, dtype=np.float32)
        segments.append(
            segment_cls(text=ctx_text, speaker=int(ctx_speaker), audio=empty_audio)
        )
    return segments


def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9,
                    topk=50, context_texts=None, context_speakers=None):
    """Synthesize speech for ``text`` using the module-level generator.

    Args:
        text: Text to speak.
        speaker_id: Speaker identity; converted with ``int()``.
        max_audio_length_ms: Upper bound on output length, in milliseconds.
        temperature: Sampling temperature passed to the generator.
        topk: Top-k sampling parameter passed to the generator.
        context_texts: Optional list of previous utterance texts.
        context_speakers: Optional list of speaker IDs, parallel to ``context_texts``.

    Returns:
        ``(audio, error)``. ``audio`` is ``(sample_rate, numpy_array)``, or
        ``None`` when ``text`` is blank. ``error`` is ``None`` on success,
        otherwise a human-readable message. If generation raises, ``audio``
        holds silent output from a ``MockGenerator``.
    """
    global generator
    if not text or not text.strip():
        return None, "Please enter some text to convert to speech."

    if generator is None:
        # initialize_model() normally always installs a generator; this guard
        # keeps the call safe if it ever does not.
        if not initialize_model() or generator is None:
            generator = MockGenerator()

    try:
        context_segments = _build_context_segments(
            _resolve_segment_class(), context_texts, context_speakers
        )
        audio = generator.generate(
            text=text,
            speaker=int(speaker_id),
            context=context_segments,
            max_audio_length_ms=float(max_audio_length_ms),
            temperature=float(temperature),
            topk=int(topk),
        )
        # The Gradio Audio output is configured with type="numpy", so convert
        # tensors to a CPU numpy array (the mock already returns numpy).
        if HAS_TORCH and isinstance(audio, torch.Tensor):
            audio = audio.cpu().numpy()
        return (generator.sample_rate, audio), None
    except Exception as e:
        logger.exception(f"Error generating audio: {str(e)}")
        mock_gen = MockGenerator()
        silent = mock_gen.generate(
            text=text,
            speaker=speaker_id,
            max_audio_length_ms=float(max_audio_length_ms),
        )
        return (mock_gen.sample_rate, silent), f"Error generating audio, using silent audio: {str(e)}"


def clear_context():
    """Return fresh empty lists for the context-text and context-speaker states."""
    return [], []


def add_context(text, speaker_id, context_texts, context_speakers):
    """Return new context lists with ``(text, speaker_id)`` appended.

    Blank or whitespace-only text and a ``None`` speaker are ignored. The input
    lists are copied rather than mutated in place.
    """
    texts = list(context_texts or [])
    speakers = list(context_speakers or [])
    if text and text.strip() and speaker_id is not None:
        texts.append(text)
        speakers.append(int(speaker_id))
    return texts, speakers


def update_context_display(texts, speakers):
    """Return ``[[text, speaker], ...]`` rows for the context table (``[]`` if empty)."""
    if not texts or not speakers:
        return []
    return [[text, speaker] for text, speaker in zip(texts, speakers)]


def create_demo():
    """Build and return the Gradio Blocks UI for the CSM 1B demo."""
    demo = gr.Blocks(title="CSM 1B Demo")

    def _generate_for_ui(*args):
        """Run generate_speech; show the error box only when there is an error."""
        audio, error = generate_speech(*args)
        return audio, gr.update(value=error or "", visible=bool(error))

    with demo:
        gr.Markdown("# CSM 1B - Conversational Speech Model")
        gr.Markdown("Enter text to generate natural-sounding speech with the CSM 1B model")

        if not HAS_TORCH:
            gr.Markdown("⚠️ **WARNING: PyTorch is not available. Using a mock generator that produces silent audio.**")
        elif isinstance(generator, MockGenerator):
            # Torch is present but initialize_model() fell back to the mock.
            gr.Markdown("⚠️ **WARNING: The CSM 1B model could not be loaded. Using a mock generator that produces silent audio.**")

        with gr.Row():
            with gr.Column(scale=2):
                text_input = gr.Textbox(
                    label="Text to convert to speech",
                    placeholder="Enter your text here...",
                    lines=3,
                )
                speaker_id = gr.Slider(
                    label="Speaker ID", minimum=0, maximum=10, step=1, value=0
                )

                with gr.Accordion("Advanced Options", open=False):
                    max_length = gr.Slider(
                        label="Maximum length (milliseconds)",
                        minimum=1000, maximum=30000, step=1000, value=10000,
                    )
                    temp = gr.Slider(
                        label="Temperature", minimum=0.1, maximum=1.5, step=0.1, value=0.9
                    )
                    top_k = gr.Slider(
                        label="Top K", minimum=10, maximum=100, step=10, value=50
                    )

                with gr.Accordion("Conversation Context", open=False):
                    # Per-session parallel lists of context texts and speaker IDs.
                    context_list = gr.State([])
                    context_speakers_list = gr.State([])

                    with gr.Row():
                        context_text = gr.Textbox(label="Context text", lines=2)
                        context_speaker = gr.Slider(
                            label="Context speaker ID", minimum=0, maximum=10, step=1, value=0
                        )

                    with gr.Row():
                        add_ctx_btn = gr.Button("Add Context")
                        clear_ctx_btn = gr.Button("Clear All Context")

                    context_display = gr.Dataframe(
                        headers=["Text", "Speaker ID"],
                        label="Current Context",
                        interactive=False,
                    )

                generate_btn = gr.Button("Generate Audio", variant="primary")

            with gr.Column(scale=1):
                audio_output = gr.Audio(label="Generated Audio", type="numpy")
                # Hidden until _generate_for_ui reports an error.
                error_output = gr.Textbox(label="Error Message", visible=False)

        generate_btn.click(
            fn=_generate_for_ui,
            inputs=[
                text_input, speaker_id, max_length, temp, top_k,
                context_list, context_speakers_list,
            ],
            outputs=[audio_output, error_output],
        )

        add_ctx_btn.click(
            fn=add_context,
            inputs=[context_text, context_speaker, context_list, context_speakers_list],
            outputs=[context_list, context_speakers_list],
        ).then(
            fn=update_context_display,
            inputs=[context_list, context_speakers_list],
            outputs=[context_display],
        )

        clear_ctx_btn.click(
            fn=clear_context,
            inputs=[],
            outputs=[context_list, context_speakers_list],
        ).then(
            fn=lambda: [],
            inputs=[],
            outputs=[context_display],
        )

        gr.Markdown("""
        ## About CSM-1B
        CSM (Conversational Speech Model) is a speech generation model from Sesame that generates audio from text inputs.
        The model can generate a variety of voices and works best when provided with conversational context.

        ### Features:
        - Generate natural-sounding speech from text
        - Choose different speaker identities (0-10)
        - Adjust temperature to control output variability
        - Add conversation context for more natural responses

        [View on Hugging Face](https://huggingface.co/sesame/csm-1b) | [GitHub Repository](https://github.com/SesameAILabs/csm)
        """)

    return demo


# Load the model (or fall back to the mock) once at startup. create_demo()
# reads the result to decide which warning banner to show.
initialize_model()

# Build and launch the UI. NOTE: share=True also publishes the app through a
# public Gradio share link in addition to binding 0.0.0.0:7860.
demo = create_demo()
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)