"""FastAPI service exposing Sesame's CSM 1B conversational speech model.

Endpoints:
    POST /generate-audio -- synthesize speech for ``text`` as ``speaker``,
        optionally conditioned on earlier conversation turns, and return the
        result as base64-encoded WAV bytes plus its sample rate.
    GET  /health         -- report whether the model has finished loading.
"""

import base64
import binascii
import io
import logging
from typing import List, Optional

import torch
import torchaudio
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field

from generator import load_csm_1b, Segment

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="CSM 1B API",
    description="API for Sesame's Conversational Speech Model",
    version="1.0.0",
)

# NOTE(review): wildcard origins combined with allow_credentials=True lets any
# web page make credentialed cross-origin calls; restrict allow_origins before
# exposing this service beyond a trusted network.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Model handle; stays None until startup_event's load_csm_1b call succeeds.
generator = None


class SegmentRequest(BaseModel):
    """One earlier conversation turn used as generation context."""

    speaker: int
    text: str
    # Base64-encoded audio file readable by torchaudio.load; when omitted the
    # segment is passed to the model with an empty audio tensor.
    audio_base64: Optional[str] = None


class GenerateAudioRequest(BaseModel):
    """Body of POST /generate-audio; sampling fields are forwarded to generator.generate."""

    text: str
    speaker: int
    # Prior turns, in order, used to condition generation.
    context: List[SegmentRequest] = Field(default_factory=list)
    max_audio_length_ms: float = 10000
    temperature: float = 0.9
    topk: int = 50


class AudioResponse(BaseModel):
    """Response of POST /generate-audio."""

    # Base64 (UTF-8 text) encoding of a complete WAV file.
    audio_base64: str
    # Sample rate of the WAV, taken from generator.sample_rate.
    sample_rate: int


@app.on_event("startup")
async def startup_event():
    """Load the CSM 1B model onto CUDA if available, else CPU.

    Raises:
        Exception: whatever load_csm_1b raises is logged and re-raised so the
            application fails to start instead of serving without a model.
    """
    global generator
    logger.info("Loading CSM 1B model...")

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device == "cpu":
        logger.warning("GPU not available. Using CPU, performance may be slow!")

    try:
        generator = load_csm_1b(device=device)
        logger.info(f"Model loaded successfully on device: {device}")
    except Exception as e:
        logger.exception(f"Could not load model: {str(e)}")
        raise


def _decode_context_audio(audio_base64: str, target_sample_rate: int) -> torch.Tensor:
    """Decode base64 audio into a 1-D mono tensor resampled to target_sample_rate.

    Raises:
        ValueError: if the payload is not valid base64 or torchaudio cannot
            decode the bytes as audio.
    """
    try:
        audio_bytes = base64.b64decode(audio_base64)
    except binascii.Error as e:
        raise ValueError(f"invalid base64 audio: {e}") from e

    try:
        audio_tensor, sample_rate = torchaudio.load(io.BytesIO(audio_bytes))
    except Exception as e:  # error type depends on the torchaudio backend
        raise ValueError(f"could not decode audio: {e}") from e

    # torchaudio.load returns (channels, frames). Averaging the channel axis gives
    # 1-D audio for mono and multi-channel input alike (squeeze(0) only handled mono).
    mono = audio_tensor.mean(dim=0)
    return torchaudio.functional.resample(
        mono, orig_freq=sample_rate, new_freq=target_sample_rate
    )


@app.post("/generate-audio", response_model=AudioResponse)
async def generate_audio(request: GenerateAudioRequest):
    """Generate speech for request.text, conditioned on request.context.

    Returns:
        AudioResponse with the generated audio as base64-encoded WAV bytes.

    Raises:
        HTTPException: 503 if the model is not loaded yet, 400 if a context
            segment's audio cannot be decoded, 500 if generation or WAV
            encoding fails.
    """
    if generator is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Please try again later.")

    # NOTE(review): generator.generate is called synchronously inside an async
    # handler, so it blocks the event loop (including /health) for the whole
    # generation; consider offloading it to a worker thread with a lock.
    try:
        context_segments = []
        for index, segment in enumerate(request.context):
            if segment.audio_base64:
                try:
                    audio_tensor = _decode_context_audio(
                        segment.audio_base64, generator.sample_rate
                    )
                except ValueError as e:
                    # Bad client-supplied audio is a request error, not a server fault.
                    raise HTTPException(status_code=400, detail=f"context[{index}]: {e}") from e
            else:
                audio_tensor = torch.zeros(0, dtype=torch.float32)

            context_segments.append(
                Segment(text=segment.text, speaker=segment.speaker, audio=audio_tensor)
            )

        audio = generator.generate(
            text=request.text,
            speaker=request.speaker,
            context=context_segments,
            max_audio_length_ms=request.max_audio_length_ms,
            temperature=request.temperature,
            topk=request.topk,
        )

        # torchaudio.save expects (channels, frames); add the channel axis.
        buffer = io.BytesIO()
        torchaudio.save(buffer, audio.unsqueeze(0).cpu(), generator.sample_rate, format="wav")
    except HTTPException:
        # Already carries the right status code; don't re-wrap it as a 500.
        raise
    except Exception as e:
        logger.exception(f"error when building audio: {str(e)}")
        raise HTTPException(status_code=500, detail=f"error when building audio: {str(e)}") from e

    audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
    return AudioResponse(audio_base64=audio_base64, sample_rate=generator.sample_rate)


@app.get("/health")
async def health_check():
    """Report readiness: "ready" once the model is loaded, otherwise "not_ready"."""
    if generator is None:
        return {"status": "not_ready", "message": "Model is loading"}

    return {"status": "ready", "message": "API is ready to serve"}