import io
import struct
import wave

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse, Response, StreamingResponse

from kokoro import KPipeline

app = FastAPI(title="Kokoro TTS FastAPI")

# Load the Kokoro pipeline once at startup; lang_code="a" selects American English.
pipeline = KPipeline(lang_code="a")


def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Generate a WAV header for streaming.

    The final audio length is unknown when streaming starts, so the RIFF and
    data chunk sizes are set to a large dummy value; browsers tolerate this
    for progressive playback.
    """
    bits_per_sample = sample_width * 8
    byte_rate = sample_rate * num_channels * sample_width
    block_align = num_channels * sample_width
    total_size = 36 + data_size
    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
    data_chunk_header = struct.pack('<4sI', b'data', data_size)
    return header + fmt_chunk + data_chunk_header
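

# Sanity check (illustrative): RIFF header (12 bytes) + fmt chunk (24) + data
# chunk header (8) make a canonical 44-byte PCM WAV header, e.g.
#   assert len(generate_wav_header(24000, 1, 2)) == 44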


def custom_split_text(text: str) -> list[str]:
    """
    Split text into progressively larger chunks.

    - Start with a chunk size of 2 words.
    - Within each candidate chunk, if any word except the last contains a
      period ("."), end the chunk at that word (inclusive).
    - Otherwise, use the current chunk size.
    - Increase the chunk size by 2 for each subsequent chunk.
    - If fewer words remain than the current chunk size, take all of them.
    """
    words = text.split()
    chunks = []
    chunk_size = 2
    start = 0
    while start < len(words):
        candidate_end = start + chunk_size
        if candidate_end > len(words):
            candidate_end = len(words)
        chunk_words = words[start:candidate_end]

        # Look for a sentence-ending period anywhere except the last word.
        split_index = None
        for i in range(len(chunk_words) - 1):
            if '.' in chunk_words[i]:
                split_index = i
                break
        if split_index is not None:
            candidate_end = start + split_index + 1
            chunk_words = words[start:candidate_end]

        chunks.append(" ".join(chunk_words))
        start = candidate_end
        chunk_size += 2
    return chunks
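

# Illustrative trace of the widening-chunk behavior:
#   custom_split_text("One. Two three four five six")
#   -> ["One.", "Two three four five", "six"]
# The first chunk is tiny so audio starts quickly; later chunks grow to
# amortize per-call pipeline overhead.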


def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
    """
    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    audio_int16 = np.int16(audio_np * 32767)
    return audio_int16.tobytes()
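

# Illustrative: torch.tensor([0.0, 1.0, -1.0]) becomes int16 samples
# [0, 32767, -32767], i.e. 6 bytes of little-endian PCM on typical platforms.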


@app.get("/tts/streaming", summary="Streaming TTS")
def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0):
    """
    Streaming TTS endpoint that returns a continuous WAV stream.

    The endpoint first yields a WAV header (with a dummy data length), then
    yields raw PCM data for each text chunk as soon as it is generated.
    """
    chunks = custom_split_text(text)
    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM

    def audio_generator():
        header = generate_wav_header(sample_rate, num_channels, sample_width)
        yield header

        for i, chunk in enumerate(chunks):
            print(f"Processing chunk {i}: {chunk}")
            try:
                # Iterate the pipeline lazily so PCM is yielded as soon as each
                # result is ready, instead of buffering the whole chunk first.
                for result in pipeline(chunk, voice=voice, speed=speed, split_pattern=None):
                    if result.audio is not None:
                        print(f"Chunk {i}: Audio generated")
                        yield audio_tensor_to_pcm_bytes(result.audio)
                    else:
                        print(f"Chunk {i}: No audio generated")
            except Exception as e:
                print(f"Error processing chunk {i}: {e}")

    return StreamingResponse(
        audio_generator(),
        media_type="audio/wav",
        headers={"Cache-Control": "no-cache"},
    )
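

# Example request (assuming the server runs locally on port 7860):
#   curl -N "http://localhost:7860/tts/streaming?text=Hello%20world.&voice=af_heart" -o streamed.wav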


@app.get("/tts/full", summary="Full TTS")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0):
    """
    Full TTS endpoint that synthesizes the entire text, concatenates the audio,
    and returns a complete WAV file.
    """
    results = list(pipeline(text, voice=voice, speed=speed, split_pattern=r"\n+"))
    audio_segments = []
    for result in results:
        if result.audio is not None:
            audio_np = result.audio.cpu().numpy()
            if audio_np.ndim > 1:
                audio_np = audio_np.flatten()
            audio_segments.append(audio_np)

    if not audio_segments:
        raise HTTPException(status_code=500, detail="No audio generated.")

    full_audio = np.concatenate(audio_segments)

    sample_rate = 24000
    num_channels = 1
    sample_width = 2  # 16-bit PCM
    wav_io = io.BytesIO()
    with wave.open(wav_io, "wb") as wav_file:
        wav_file.setnchannels(num_channels)
        wav_file.setsampwidth(sample_width)
        wav_file.setframerate(sample_rate)
        full_audio_int16 = np.int16(full_audio * 32767)
        wav_file.writeframes(full_audio_int16.tobytes())
    wav_io.seek(0)
    return Response(content=wav_io.read(), media_type="audio/wav")
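

# Unlike /tts/streaming, this response carries correct RIFF/data sizes (the
# wave module fills them in on close), so players can show a duration and seek:
#   curl "http://localhost:7860/tts/full?text=Hello%20world." -o full.wav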


@app.get("/", response_class=HTMLResponse)
def index():
    """
    HTML demo page for Kokoro TTS.

    Two playback methods are provided:
      - "Play Streaming TTS" sets the <audio> element's src to the streaming endpoint.
      - "Play Full TTS" sets the <audio> element's src to the full synthesis endpoint.
    The browser's native playback handles streaming (progressive download) of the WAV data.
    """
    return """
    <!DOCTYPE html>
    <html>
      <head>
        <title>Kokoro TTS Demo</title>
      </head>
      <body>
        <h1>Kokoro TTS Demo</h1>
        <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
        <label for="voice">Voice:</label>
        <input type="text" id="voice" value="af_heart"><br>
        <label for="speed">Speed:</label>
        <input type="number" step="0.1" id="speed" value="1.0"><br><br>
        <button onclick="playStreaming()">Play Streaming TTS</button>
        <button onclick="playFull()">Play Full TTS</button>
        <br><br>
        <audio id="audioPlayer" controls autoplay></audio>
        <script>
          function playStreaming() {
            const text = document.getElementById('text').value;
            const voice = document.getElementById('voice').value;
            const speed = document.getElementById('speed').value;
            const audio = document.getElementById('audioPlayer');
            // Simply point the audio element to the streaming endpoint.
            audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
            audio.play();
          }
          function playFull() {
            const text = document.getElementById('text').value;
            const voice = document.getElementById('voice').value;
            const speed = document.getElementById('speed').value;
            const audio = document.getElementById('audioPlayer');
            // Simply point the audio element to the full synthesis endpoint.
            audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}`;
            audio.play();
          }
        </script>
      </body>
    </html>
    """


if __name__ == "__main__":
    import uvicorn

    # "app:app" assumes this module is saved as app.py; adjust the import
    # string if the file is named differently.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)