File size: 3,295 Bytes
5ca847f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import io
import os
import tempfile
from typing import List, Optional

import TTS.api
import torch
from pydub import AudioSegment

from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import StreamingResponse, Response

import config

# Run inference on the GPU when torch reports one, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load every entry of ``config.models`` once at import time and move it to the
# selected device, keyed by the same id used in the config.
# A comprehension keeps the loop variables out of module scope, so the builtin
# ``id`` is no longer shadowed for the rest of the module.
# NOTE(review): each config value is passed as the first positional argument
# of ``TTS.api.TTS`` -- presumably a model name or path; confirm against config.
models = {
    model_id: TTS.api.TTS(model_spec).to(device)
    for model_id, model_spec in config.models.items()
}

class SynthesizeResponse(Response):
    """Response subclass whose declared media type is WAV audio."""

    # Used as ``response_class`` on the ``/tts`` route in this file.
    media_type = "audio/wav"

# Application instance; the route handlers in this module register on it
# through its decorators (e.g. ``@app.post``).
app = FastAPI()

@app.post('/tts', response_class=SynthesizeResponse)
async def synthesize(
    text: str = Form('Hello,World!'),
    speaker_wavs: List[UploadFile] = File(None),
    speaker_idx: str = Form('Ana Florence'),
    language: str = Form('ja'),
    temperature: float = Form(0.65),
    length_penalty: float = Form(1.0),
    repetition_penalty: float = Form(2.0),
    top_k: int = Form(50),
    top_p: float = Form(0.8),
    speed: float = Form(1.0),
    enable_text_splitting: bool = Form(True)
) -> StreamingResponse:
    """Synthesize speech for ``text`` and return it as WAV audio.

    The model stored under ``models['multi']`` performs the synthesis. When
    reference recordings are uploaded, each one is converted to WAV with
    pydub, written to a temporary file, and the list of paths is passed to
    the model as ``speaker_wav``. Otherwise ``speaker_idx`` is passed as
    ``speaker``.

    Args:
        text: Text to synthesize.
        speaker_wavs: Optional uploaded reference recordings. Any format
            pydub can read is accepted; omit the field to use
            ``speaker_idx`` instead.
        speaker_idx: Speaker passed to the model when no recordings are
            uploaded.
        language: Language code forwarded to the model.
        temperature: Forwarded unchanged to ``tts_to_file``.
        length_penalty: Forwarded unchanged to ``tts_to_file``.
        repetition_penalty: Forwarded unchanged to ``tts_to_file``.
        top_k: Forwarded unchanged to ``tts_to_file``.
        top_p: Forwarded unchanged to ``tts_to_file``.
        speed: Forwarded unchanged to ``tts_to_file``.
        enable_text_splitting: Forwarded unchanged to ``tts_to_file``.

    Returns:
        A ``StreamingResponse`` with media type ``audio/wav`` wrapping the
        in-memory buffer the model wrote to.

    Raises:
        HTTPException: Status 400 when an uploaded recording cannot be
            decoded or re-encoded as WAV.
    """
    temp_files = []
    try:
        # ``speaker_wavs`` is None when the form field is omitted.
        for speaker_wav in speaker_wavs or []:
            speaker_wav_bytes = await speaker_wav.read()
            # Normalise whatever was uploaded to WAV using pydub.
            try:
                audio = AudioSegment.from_file(io.BytesIO(speaker_wav_bytes))
                wav_buffer = io.BytesIO()
                audio.export(wav_buffer, format="wav")
            except Exception as e:
                # Chain the cause so the original decoding error stays in logs.
                raise HTTPException(status_code=400, detail=f"Error processing audio file: {e}") from e

            # delete=False because the file must outlive this handle: the
            # model receives it by path. The name is registered *before*
            # writing so the ``finally`` below removes the file even when the
            # write fails, and ``with`` guarantees the handle is closed.
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_wav_file:
                temp_files.append(temp_wav_file.name)
                temp_wav_file.write(wav_buffer.getvalue())

        # The two synthesis modes differ only in how the voice is selected.
        if temp_files:
            voice_kwargs = {'speaker_wav': temp_files}
        else:
            voice_kwargs = {'speaker': speaker_idx}

        # An in-memory buffer is handed over as ``file_path`` so the audio
        # never touches disk.
        # NOTE(review): this call is synchronous inside a coroutine, so it
        # presumably blocks the event loop for the whole synthesis -- confirm
        # whether offloading to a worker thread is wanted.
        output_buffer = io.BytesIO()
        models['multi'].tts_to_file(
            text=text,
            language=language,
            file_path=output_buffer,
            temperature=temperature,
            length_penalty=length_penalty,
            repetition_penalty=repetition_penalty,
            top_k=top_k,
            top_p=top_p,
            speed=speed,
            enable_text_splitting=enable_text_splitting,
            **voice_kwargs
        )
        output_buffer.seek(0)
        return StreamingResponse(output_buffer, media_type="audio/wav")
    finally:
        # Remove the reference recordings whether synthesis succeeded or not.
        # Attempt the removal directly instead of checking existence first,
        # which avoids a check-then-act race.
        for temp_file in temp_files:
            try:
                os.remove(temp_file)
            except FileNotFoundError:
                pass