import io
import json
import os
import struct
import time

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import HTMLResponse, Response, StreamingResponse
from huggingface_hub import snapshot_download
from misaki import en, espeak
from onnxruntime import InferenceSession
from scipy.io.wavfile import write as write_wav

# Load the phoneme vocabulary (phoneme -> integer token id) from the model config.
config_file_path = 'config.json'
with open(config_file_path, 'r') as f:
    config = json.load(f)
phoneme_vocab = config['vocab']

# Download the quantized Kokoro ONNX model and the voice style vectors.
model_repo = "onnx-community/Kokoro-82M-v1.0-ONNX"
model_name = "onnx/model_q4.onnx"
voice_file_pattern = "*.bin"
local_dir = "."
snapshot_download(
    repo_id=model_repo,
    allow_patterns=[model_name, voice_file_pattern],
    local_dir=local_dir,
)
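
# With the patterns above, files land under ./onnx/ and ./voices/ relative to
# local_dir (e.g. ./voices/af_heart.bin), which is the layout the endpoints
# below expect.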

# The ONNX model exposes three inputs: "input_ids", "style", and "speed"
# (see the sess.run calls below).
model_path = os.path.join(local_dir, model_name)
sess = InferenceSession(model_path)

app = FastAPI(
    title="Kokoro TTS FastAPI",
    middleware=[Middleware(GZipMiddleware, compresslevel=9)],
)


def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Generate a WAV header for streaming. Since we do not know the final audio
    size up front, a large dummy value is used for the data chunk size.
    """
    bits_per_sample = sample_width * 8
    byte_rate = sample_rate * num_channels * sample_width
    block_align = num_channels * sample_width
    total_size = 36 + data_size
    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
    data_chunk_header = struct.pack('<4sI', b'data', data_size)
    return header + fmt_chunk + data_chunk_header


# 24 kHz, mono, 16-bit PCM: the output format of the Kokoro model.
stream_header = generate_wav_header(24000, 1, 2)
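
# Sanity check (illustrative): a PCM WAV header with a 16-byte fmt chunk is
# 12 (RIFF) + 24 (fmt) + 8 (data header) = 44 bytes.
assert len(stream_header) == 44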


def custom_split_text(text: str) -> list:
    """
    Custom splitting strategy:
      - Start with a chunk size of 2 words.
      - If a period (".") appears in any word of the chunk except the very
        last word, split at that word (including it).
      - Otherwise, use the current chunk size.
      - Increase the chunk size by 2 for each subsequent chunk (capped at 100).
      - If fewer words remain than the chunk size, include all of them.
    """
    words = text.split()
    chunks = []
    chunk_size = 2
    start = 0
    while start < len(words):
        candidate_end = start + chunk_size
        if candidate_end > len(words):
            candidate_end = len(words)
        chunk_words = words[start:candidate_end]
        # Prefer to end a chunk at a sentence boundary so each synthesized
        # chunk lands on a natural pause where possible.
        split_index = None
        for i in range(len(chunk_words) - 1):
            if '.' in chunk_words[i]:
                split_index = i
                break
        if split_index is not None:
            candidate_end = start + split_index + 1
            chunk_words = words[start:candidate_end]
        chunks.append(" ".join(chunk_words))
        start = candidate_end
        if chunk_size < 100:
            chunk_size += 2
    return chunks
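
# Illustrative example of the growing chunk sizes (2, 4, 6, ...):
#   custom_split_text("one two three four five six seven")
#   -> ["one two", "three four five six", "seven"]
# Small leading chunks keep time-to-first-audio low; later chunks grow so
# per-chunk inference overhead amortizes.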


def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a torch.FloatTensor (values in [-1, 1]) to raw 16-bit PCM bytes.
    """
    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    audio_int16 = np.int16(audio_np * 32767)
    return audio_int16.tobytes()


def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
    """
    Convert a torch.FloatTensor to raw Opus-encoded frames.
    Requires the 'opuslib' package: pip install opuslib
    Note: the `bitrate` argument is currently not applied to the encoder.
    """
    try:
        import opuslib
    except ImportError:
        raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")

    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()
    audio_int16 = np.int16(audio_np * 32767)

    encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP)
    frame_size = int(sample_rate * 0.020)  # 20 ms frames: 480 samples at 24 kHz
    encoded_data = b''
    for i in range(0, len(audio_int16), frame_size):
        frame = audio_int16[i:i + frame_size]
        if len(frame) < frame_size:
            # Zero-pad the final partial frame to a full 20 ms frame.
            frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
        encoded_frame = encoder.encode(frame.tobytes(), frame_size)
        encoded_data += encoded_frame
    return encoded_data
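
# Note: the function above emits bare Opus packets with no container framing.
# Browsers generally expect Opus inside an Ogg or WebM container, so the raw
# bytes returned by the endpoints below may need muxing before an <audio>
# element can play them.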


# Grapheme-to-phoneme conversion (American English, with an espeak fallback).
fbs = espeak.EspeakFallback(british=True)
g2p = en.G2P(trf=False, british=False, fallback=fbs)


def tokenizer(text: str):
    """
    Convert text to a list of phoneme token ids using the global vocabulary.
    Phonemes missing from the vocabulary are silently dropped.
    """
    phonemes_string, _ = g2p(text)
    print(text + " " + phonemes_string)
    tokens = [phoneme_vocab[phoneme] for phoneme in phonemes_string if phoneme in phoneme_vocab]
    return tokens
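
# Illustrative usage: tokenizer("Hello world.") returns one integer id per
# phoneme emitted by g2p; the actual ids depend on config.json's vocab.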


@app.get("/tts/streaming", summary="Streaming TTS")
def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
    """
    Streaming TTS endpoint.

    Splits the input text into chunks (using the growing-chunk strategy),
    then for each chunk:
      - For the first chunk, a 0 is prepended.
      - For subsequent chunks, the previous chunk's last token is prepended
        as context.
      - For the final chunk, a 0 is appended.

    The audio for each chunk is generated immediately and streamed to the client.
    """
    if format.lower() not in ("wav", "opus"):
        raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")

    chunks = custom_split_text(text)

    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    if not os.path.exists(voice_path):
        raise HTTPException(status_code=404, detail="Voice file not found")
    # Each row is a (1, 256) style vector; the row index is chosen by
    # sequence length below.
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)

    def audio_generator():
        # For WAV output, emit the header once up front; raw PCM chunks follow.
        if format.lower() == "wav":
            yield stream_header

        prev_last_token = None
        for i, chunk in enumerate(chunks):
            chunk_tokens = tokenizer(chunk)

            # Build the token sequence per the strategy in the docstring.
            if i == 0:
                tokens_to_send = [0] + chunk_tokens
            else:
                tokens_to_send = prev_last_token + chunk_tokens
            if i == len(chunks) - 1:
                tokens_to_send = tokens_to_send + [0]
            prev_last_token = chunk_tokens[-1:]

            # onnxruntime expects numpy inputs, not plain Python lists.
            final_token = np.array([tokens_to_send], dtype=np.int64)

            # Style vector indexed by token count (plus the two boundary
            # tokens), clamped to the available rows.
            style_index = min(len(chunk_tokens) + 2, len(voices) - 1)
            ref_s = voices[style_index]

            speed_param = np.full(1, speed, dtype=np.float32)

            try:
                start_time = time.time()
                audio_output = sess.run(None, {
                    "input_ids": final_token,
                    "style": ref_s,
                    "speed": speed_param,
                })[0]
                print(f"Chunk {i} inference time: {time.time() - start_time:.3f}s")
            except Exception as e:
                print(f"Error processing chunk {i}: {e}")
                # Substitute one second of silence so the stream keeps going.
                audio_output = np.zeros((24000,), dtype=np.float32)

            # Convert to 16-bit PCM, trimming the model's padding at both ends
            # (0.25 s head, 0.125 s tail at 24 kHz).
            audio_int16 = (audio_output * 32767).astype(np.int16).flatten()[6000:-3000]

            if format.lower() == "opus":
                yield audio_tensor_to_opus_bytes(
                    torch.from_numpy(audio_int16.astype(np.float32) / 32767),
                    sample_rate=24000,
                )
            else:
                yield audio_int16.tobytes()

    media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
    return StreamingResponse(
        audio_generator(),
        media_type=media_type,
        headers={"Cache-Control": "no-cache"},
    )
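
# Example request (illustrative; adjust host/port to your deployment):
#   curl "http://localhost:7860/tts/streaming?text=Hello%20there.%20How%20are%20you%3F" -o out.wav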


@app.get("/tts/full", summary="Full TTS")
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
    """
    Full TTS endpoint that synthesizes the entire text in one pass and returns
    a complete WAV or Opus file.
    """
    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    if not os.path.exists(voice_path):
        raise HTTPException(status_code=404, detail="Voice file not found")
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)

    tokens = tokenizer(text)
    # Style vector indexed by token count, clamped to the available rows.
    ref_s = voices[min(len(tokens), len(voices) - 1)]
    # onnxruntime expects numpy inputs, not plain Python lists.
    final_token = np.array([[0, *tokens, 0]], dtype=np.int64)

    start_time = time.time()
    audio = sess.run(None, {
        "input_ids": final_token,
        "style": ref_s,
        "speed": np.full(1, speed, dtype=np.float32),
    })[0]
    print(f"Full TTS inference time: {time.time() - start_time:.3f}s")

    audio = (audio * 32767).astype(np.int16).flatten()

    if format.lower() == "wav":
        wav_io = io.BytesIO()
        write_wav(wav_io, 24000, audio)
        wav_io.seek(0)
        return Response(content=wav_io.read(), media_type="audio/wav")
    elif format.lower() == "opus":
        opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(audio.astype(np.float32) / 32767), sample_rate=24000)
        return Response(content=opus_data, media_type="audio/opus")
    else:
        raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}")
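
# Example request (illustrative):
#   curl "http://localhost:7860/tts/full?text=Hello%20world&format=wav" -o full.wav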


@app.get("/", response_class=HTMLResponse)
def index():
    """
    HTML demo page for Kokoro TTS. WAV is the default format, since the raw
    Opus frames produced above are not directly playable in an <audio> element.
    """
    return """
    <!DOCTYPE html>
    <html>
    <head>
        <title>Kokoro TTS Demo</title>
    </head>
    <body>
        <h1>Kokoro TTS Demo</h1>
        <textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br>
        <label for="voice">Voice:</label>
        <input type="text" id="voice" value="af_heart"><br>
        <label for="speed">Speed:</label>
        <input type="number" step="0.1" id="speed" value="1.0"><br>
        <label for="format">Format:</label>
        <select id="format">
            <option value="wav" selected>WAV</option>
            <option value="opus">Opus</option>
        </select><br><br>
        <button onclick="play('/tts/streaming')">Play Streaming TTS</button>
        <button onclick="play('/tts/full')">Play Full TTS</button>
        <br><br>
        <audio id="audio" controls autoplay></audio>
        <script>
            function play(endpoint) {
                const text = document.getElementById('text').value;
                const voice = document.getElementById('voice').value;
                const speed = document.getElementById('speed').value;
                const format = document.getElementById('format').value;
                const audio = document.getElementById('audio');
                audio.src = `${endpoint}?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`;
                audio.play();
            }
        </script>
    </body>
    </html>
    """


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)