import logging

import torch
import gradio as gr

from generator import Segment, Model, Generator

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Loaded lazily and cached across requests
generator = None


def initialize_model():
    """Load the CSM 1B model and cache a Generator in the module-level global."""
    global generator
    logger.info("Loading CSM 1B model...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        logger.warning("GPU not available. Using CPU; performance may be slow!")
    logger.info(f"Using device: {device}")
    try:
        model = Model.from_pretrained("sesame/csm-1b")
        model = model.to(device=device)
        generator = Generator(model)
        logger.info(f"Model loaded successfully on device: {device}")
        return True
    except Exception as e:
        logger.error(f"Could not load model: {str(e)}")
        return False


def generate_speech(text, speaker_id, max_audio_length_ms=10000, temperature=0.9,
                    topk=50, context_texts=None, context_speakers=None):
    """Generate speech for `text`, optionally conditioned on prior conversation turns."""
    global generator
    if generator is None:
        if not initialize_model():
            return None, "Could not load model. Please try again later."

    try:
        # Build context segments if prior turns were provided. Only the
        # text/speaker history is available here, so each segment carries
        # an empty audio tensor.
        context_segments = []
        if context_texts and context_speakers:
            for ctx_text, ctx_speaker in zip(context_texts, context_speakers):
                if ctx_text and ctx_speaker is not None:
                    context_segments.append(
                        Segment(
                            text=ctx_text,
                            speaker=int(ctx_speaker),
                            audio=torch.zeros(0, dtype=torch.float32),
                        )
                    )

        # Generate audio from text
        audio = generator.generate(
            text=text,
            speaker=int(speaker_id),
            context=context_segments,
            max_audio_length_ms=float(max_audio_length_ms),
            temperature=float(temperature),
            topk=int(topk),
        )

        # Convert the tensor to a numpy array so Gradio can play it
        audio_numpy = audio.cpu().numpy()
        sample_rate = generator.sample_rate
        return (sample_rate, audio_numpy), None
    except Exception as e:
        logger.error(f"Error generating audio: {str(e)}")
        return None, f"Error generating audio: {str(e)}"
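
# For readers without generator.py at hand, this is the interface the code
# above assumes (a sketch, not the actual module; names and signatures may
# differ in the real implementation):
#
#   from dataclasses import dataclass
#   from typing import List, Protocol
#   import torch
#
#   @dataclass
#   class Segment:
#       text: str
#       speaker: int
#       audio: torch.Tensor  # 1-D waveform; empty for text-only context
#
#   class GeneratorProtocol(Protocol):
#       sample_rate: int
#       def generate(self, text: str, speaker: int, context: List[Segment],
#                    max_audio_length_ms: float, temperature: float,
#                    topk: int) -> torch.Tensor: ...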

def clear_context():
    """Reset both context lists."""
    return [], []


def add_context(text, speaker_id, context_texts, context_speakers):
    """Append one (text, speaker) pair to the context lists."""
    if text and speaker_id is not None:
        context_texts.append(text)
        context_speakers.append(int(speaker_id))
    return context_texts, context_speakers


# Set up the Gradio interface
with gr.Blocks(title="CSM 1B Demo") as demo:
    gr.Markdown("# CSM 1B - Conversational Speech Model")
    gr.Markdown("Enter text to generate natural-sounding speech with the CSM 1B model.")

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Text to convert to speech",
                placeholder="Enter your text here...",
                lines=3,
            )
            speaker_id = gr.Slider(
                label="Speaker ID", minimum=0, maximum=10, step=1, value=0
            )

            with gr.Accordion("Advanced Options", open=False):
                max_length = gr.Slider(
                    label="Maximum length (milliseconds)",
                    minimum=1000, maximum=30000, step=1000, value=10000,
                )
                temp = gr.Slider(
                    label="Temperature",
                    minimum=0.1, maximum=1.5, step=0.1, value=0.9,
                )
                top_k = gr.Slider(
                    label="Top K", minimum=10, maximum=100, step=10, value=50
                )

            with gr.Accordion("Conversation Context", open=False):
                context_list = gr.State([])
                context_speakers_list = gr.State([])

                with gr.Row():
                    context_text = gr.Textbox(label="Context text", lines=2)
                    context_speaker = gr.Slider(
                        label="Context speaker ID",
                        minimum=0, maximum=10, step=1, value=0,
                    )
                with gr.Row():
                    add_ctx_btn = gr.Button("Add Context")
                    clear_ctx_btn = gr.Button("Clear All Context")

                context_display = gr.Dataframe(
                    headers=["Text", "Speaker ID"],
                    label="Current Context",
                    interactive=False,
                )

            generate_btn = gr.Button("Generate Audio", variant="primary")

        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Audio", type="numpy")
            error_output = gr.Textbox(label="Error Message", visible=False)

    # Connect events
    generate_btn.click(
        fn=generate_speech,
        inputs=[
            text_input, speaker_id, max_length, temp, top_k,
            context_list, context_speakers_list,
        ],
        outputs=[audio_output, error_output],
    )

    add_ctx_btn.click(
        fn=add_context,
        inputs=[context_text, context_speaker, context_list, context_speakers_list],
        outputs=[context_list, context_speakers_list],
    )

    clear_ctx_btn.click(
        fn=clear_context,
        inputs=[],
        outputs=[context_list, context_speakers_list],
    )

    # Keep the context table in sync with the two State lists
    def update_context_display(texts, speakers):
        if not texts or not speakers:
            return []
        return [[text, speaker] for text, speaker in zip(texts, speakers)]

    context_list.change(
        fn=update_context_display,
        inputs=[context_list, context_speakers_list],
        outputs=[context_display],
    )
    context_speakers_list.change(
        fn=update_context_display,
        inputs=[context_list, context_speakers_list],
        outputs=[context_display],
    )

# Load the model at startup so the first request does not pay the loading cost
initialize_model()

# On Hugging Face Spaces the app is served directly, so no share link is needed
demo.launch(share=False)
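
# Illustrative only (kept commented out so it does not run alongside the app):
# calling generate_speech() directly, e.g. from a notebook or a test, and
# saving the result with torchaudio. Assumes the model returns 1-D mono audio
# at generator.sample_rate.
#
#   import torch
#   import torchaudio
#
#   result, error = generate_speech(
#       text="Hello from CSM 1B!",
#       speaker_id=0,
#       max_audio_length_ms=5000,
#   )
#   if error is None:
#       sample_rate, audio_numpy = result
#       torchaudio.save(
#           "sample.wav",
#           torch.from_numpy(audio_numpy).unsqueeze(0),  # shape (1, num_samples)
#           sample_rate,
#       )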