"""FastAPI service exposing a single ``/tts`` endpoint backed by Coqui TTS.

Every entry of ``config.models`` is loaded once at import time; the endpoint
itself uses the model registered under the ``'multi'`` key.
"""

import io
import os
import tempfile
from typing import List, Optional

import TTS.api
import torch
from pydub import AudioSegment
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import StreamingResponse, Response

import config

device = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate every configured model once and move it to the selected device.
# A comprehension keeps the loop variables out of the module namespace (the
# previous loop bound a global named ``id``, shadowing the builtin).
models = {
    model_id: TTS.api.TTS(model_spec).to(device)
    for model_id, model_spec in config.models.items()
}


class SynthesizeResponse(Response):
    """Response class declared on the ``/tts`` route; fixes the media type to WAV."""

    media_type = 'audio/wav'


app = FastAPI()


def _convert_to_wav(raw_audio: bytes) -> bytes:
    """Decode *raw_audio* with pydub and return it re-encoded as WAV bytes.

    Raises:
        HTTPException: status 400 when pydub cannot read or export the data.
    """
    try:
        audio = AudioSegment.from_file(io.BytesIO(raw_audio))
        wav_buffer = io.BytesIO()
        audio.export(wav_buffer, format="wav")
    except Exception as e:
        # Any decoding/encoding failure is reported as a client error; keep
        # the original exception chained for server-side diagnostics.
        raise HTTPException(
            status_code=400, detail=f"Error processing audio file: {e}"
        ) from e
    return wav_buffer.getvalue()


def _remove_files(paths: List[str]) -> None:
    """Delete every path in *paths*, ignoring files that are already gone."""
    for path in paths:
        try:
            os.remove(path)
        except FileNotFoundError:
            pass


@app.post('/tts', response_class=SynthesizeResponse)
async def synthesize(
    text: str = Form('Hello,World!'),
    speaker_wavs: List[UploadFile] = File(None),
    speaker_idx: str = Form('Ana Florence'),
    language: str = Form('ja'),
    temperature: float = Form(0.65),
    length_penalty: float = Form(1.0),
    repetition_penalty: float = Form(2.0),
    top_k: int = Form(50),
    top_p: float = Form(0.8),
    speed: float = Form(1.0),
    enable_text_splitting: bool = Form(True)
) -> StreamingResponse:
    """Synthesize *text* to speech and return the result as WAV audio.

    Args:
        text: Text to synthesize.
        speaker_wavs: Optional uploaded reference recordings. Each one is
            converted to WAV, written to a temporary file, and the list of
            paths is passed to the model as ``speaker_wav``.
        speaker_idx: Passed to the model as ``speaker`` when no reference
            recordings were uploaded.
        language: Passed to the model as ``language``.
        temperature, length_penalty, repetition_penalty, top_k, top_p, speed,
        enable_text_splitting: Forwarded unchanged to ``tts_to_file``.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav``.

    Raises:
        HTTPException: status 400 when an uploaded file cannot be converted.
    """
    temp_files: List[str] = []
    try:
        for upload in speaker_wavs or []:
            wav_bytes = _convert_to_wav(await upload.read())
            # delete=False: the model reads the file by path after it has
            # been closed, so removal is done explicitly in ``finally``.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
                # Register the path before writing so the file is cleaned up
                # even when the write itself fails.
                temp_files.append(tmp.name)
                tmp.write(wav_bytes)

        output_buffer = io.BytesIO()
        synthesis_kwargs = dict(
            text=text,
            language=language,
            file_path=output_buffer,
            temperature=temperature,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            top_k=top_k,
            top_p=top_p,
            speed=speed,
            enable_text_splitting=enable_text_splitting,
        )
        # Uploaded reference audio takes precedence over the named speaker.
        if temp_files:
            synthesis_kwargs["speaker_wav"] = temp_files
        else:
            synthesis_kwargs["speaker"] = speaker_idx

        # NOTE: this is a synchronous call inside a coroutine, so the event
        # loop is blocked for as long as synthesis takes.
        models['multi'].tts_to_file(**synthesis_kwargs)

        output_buffer.seek(0)
        return StreamingResponse(output_buffer, media_type="audio/wav")
    finally:
        _remove_files(temp_files)