import io
import json
import os
import struct
import time

import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware import Middleware
from fastapi.middleware.gzip import GZipMiddleware
from fastapi.responses import HTMLResponse, Response, StreamingResponse
from huggingface_hub import snapshot_download
from misaki import en
from onnxruntime import InferenceSession
from scipy.io.wavfile import write as write_wav

# Load the model configuration; its 'vocab' entry maps phoneme characters to
# the integer token ids the model expects.
config_file_path = 'config.json'

with open(config_file_path, 'r') as f:
    config = json.load(f)

phoneme_vocab = config['vocab']

model_repo = "onnx-community/Kokoro-82M-v1.0-ONNX"
model_name = "onnx/model_q8f16.onnx"
voice_file_pattern = "*.bin"
local_dir = "."

# Fetch the quantized ONNX model and the packed voice style vectors from the Hub.
snapshot_download(
    repo_id=model_repo,
    allow_patterns=[model_name, voice_file_pattern],
    local_dir=local_dir
)

model_path = os.path.join(local_dir, model_name)
sess = InferenceSession(model_path)

app = FastAPI(
    title="Kokoro TTS FastAPI",
    middleware=[
        Middleware(GZipMiddleware, compresslevel=9)
    ]
)


def generate_wav_header(sample_rate: int, num_channels: int, sample_width: int, data_size: int = 0x7FFFFFFF) -> bytes:
    """
    Generate a WAV header for streaming.
    Since we don't know the final audio size, we set the data chunk size to a large dummy value.
    This header is sent only once at the start of the stream.
    """
    bits_per_sample = sample_width * 8
    byte_rate = sample_rate * num_channels * sample_width
    block_align = num_channels * sample_width

    total_size = 36 + data_size
    header = struct.pack('<4sI4s', b'RIFF', total_size, b'WAVE')
    fmt_chunk = struct.pack('<4sIHHIIHH', b'fmt ', 16, 1, num_channels, sample_rate, byte_rate, block_align, bits_per_sample)
    data_chunk_header = struct.pack('<4sI', b'data', data_size)
    return header + fmt_chunk + data_chunk_header
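
# A quick sanity check of the layout above (an illustrative assertion, not
# part of the original file): for mono 16-bit audio the header is the
# standard 44 bytes (12-byte RIFF chunk + 24-byte fmt chunk + 8-byte data
# chunk header).
#
#   assert len(generate_wav_header(24000, 1, 2)) == 44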


def custom_split_text(text: str) -> list:
    """
    Custom splitting:
      - Start with a chunk size of 2 words.
      - For each chunk, if a period (".") is found in any word (except if it's the very last word),
        then split the chunk at that word (include words up to that word).
      - Otherwise, use the current chunk size.
      - For subsequent chunks, increase the chunk size by 2.
      - If there are fewer than the desired number of words for a full chunk, add all remaining words.
    """
    words = text.split()
    chunks = []
    chunk_size = 2
    start = 0
    while start < len(words):
        candidate_end = start + chunk_size
        if candidate_end > len(words):
            candidate_end = len(words)
        chunk_words = words[start:candidate_end]

        split_index = None
        for i in range(len(chunk_words) - 1):
            if '.' in chunk_words[i]:
                split_index = i
                break
        if split_index is not None:
            candidate_end = start + split_index + 1
            chunk_words = words[start:candidate_end]
        chunks.append(" ".join(chunk_words))
        start = candidate_end
        chunk_size += 2
    return chunks
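
# Worked example (illustrative, not from the original file): chunks grow by
# two words each step, and a period anywhere but the chunk's last word forces
# an early split.
#
#   custom_split_text("One two three. four five six seven")
#   # -> ["One two", "three.", "four five six seven"]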


def audio_tensor_to_pcm_bytes(audio_tensor: torch.Tensor) -> bytes:
    """
    Convert a torch.FloatTensor (with values in [-1, 1]) to raw 16-bit PCM bytes.
    """
    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()

    audio_int16 = np.int16(audio_np * 32767)
    return audio_int16.tobytes()


def audio_tensor_to_opus_bytes(audio_tensor: torch.Tensor, sample_rate: int = 24000, bitrate: int = 32000) -> bytes:
    """
    Convert a torch.FloatTensor to Opus encoded bytes.
    Requires the 'opuslib' package: pip install opuslib

    Note: the output is a bare concatenation of Opus frames with no Ogg
    container, so most players (including browsers) cannot decode it as a
    standalone file. The 'bitrate' parameter is currently unused; the encoder
    runs at its default rate.
    """
    try:
        import opuslib
    except ImportError:
        raise ImportError("opuslib is not installed. Please install it with: pip install opuslib")

    audio_np = audio_tensor.cpu().numpy()
    if audio_np.ndim > 1:
        audio_np = audio_np.flatten()

    # Scale float samples in [-1, 1] to 16-bit PCM for the encoder.
    audio_int16 = np.int16(audio_np * 32767)

    encoder = opuslib.Encoder(sample_rate, 1, opuslib.APPLICATION_VOIP)

    # Opus encodes fixed-size frames; 20 ms is a standard frame duration.
    frame_size = int(sample_rate * 0.020)

    encoded_data = b''
    for i in range(0, len(audio_int16), frame_size):
        frame = audio_int16[i:i + frame_size]
        if len(frame) < frame_size:
            # Zero-pad the final partial frame to a full frame.
            frame = np.pad(frame, (0, frame_size - len(frame)), 'constant')
        encoded_frame = encoder.encode(frame.tobytes(), frame_size)
        encoded_data += encoded_frame

    return encoded_data


g2p = en.G2P(trf=False, british=False, fallback=None)


def tokenizer(text: str) -> list:
    """
    Phonemize the text with misaki and map each phoneme character to its
    vocabulary id. Phonemes missing from the vocab are silently dropped.
    """
    print("Text: " + text)
    phonemes_string, _ = g2p(text)
    tokens = [phoneme_vocab[phoneme] for phoneme in phonemes_string if phoneme in phoneme_vocab]
    print(tokens)
    return tokens
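

# ---------------------------------------------------------------------------
# The demo page below calls a /tts/streaming endpoint, and the helpers above
# (generate_wav_header, custom_split_text, audio_tensor_to_pcm_bytes) are
# otherwise unused, so a streaming route evidently belongs here. Its original
# body is missing from this file; the sketch below is a hedged reconstruction
# that splits the text into growing chunks, synthesizes each chunk, and
# streams raw PCM behind a single WAV header. Treat it as an illustration,
# not the original implementation.
# ---------------------------------------------------------------------------
@app.get("/tts/streaming", summary="Streaming TTS")
def tts_streaming(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"):
    """
    Stream synthesized audio chunk by chunk as it is generated.
    """
    voice_path = os.path.join(local_dir, f"voices/{voice}.bin")
    if not os.path.exists(voice_path):
        raise HTTPException(status_code=404, detail=f"Voice not found: {voice}")
    voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256)
    sample_rate = 24000
    chunks = custom_split_text(text)

    def audio_generator():
        if format.lower() == "wav":
            # One header up front with a dummy data size, then raw PCM frames.
            yield generate_wav_header(sample_rate, 1, 2)
        for chunk in chunks:
            tokens = tokenizer(chunk)
            if not tokens:
                continue
            ref_s = voices[len(tokens)]
            audio = sess.run(None, dict(
                input_ids=np.array([[0, *tokens, 0]], dtype=np.int64),
                style=ref_s,
                speed=np.array([speed], dtype=np.float32),
            ))[0]
            if format.lower() == "wav":
                yield audio_tensor_to_pcm_bytes(torch.from_numpy(audio))
            else:
                yield audio_tensor_to_opus_bytes(torch.from_numpy(audio), sample_rate=sample_rate)

    media_type = "audio/wav" if format.lower() == "wav" else "audio/opus"
    return StreamingResponse(audio_generator(), media_type=media_type)
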
@app.get("/tts/full", summary="Full TTS") |
|
def tts_full(text: str, voice: str = "af_heart", speed: float = 1.0, format: str = "wav"): |
|
""" |
|
Full TTS endpoint that synthesizes the entire text, concatenates the audio, |
|
and returns a complete WAV or Opus file. |
|
""" |
|
voice_path = os.path.join(local_dir, f"voices/{voice}.bin") |
|
voices = np.fromfile(voice_path, dtype=np.float32).reshape(-1, 1, 256) |
|
|
|
tokens = tokenizer(text) |
|
|
|
ref_s = voices[len(tokens)] |
|
|
|
final_token = [[0, *tokens, 0]] |
|
|
|
start_time = time.time() |
|
|
|
audio = sess.run(None, dict( |
|
input_ids=final_token, |
|
style=ref_s, |
|
speed=np.ones(1, dtype=np.float32), |
|
))[0] |
|
|
|
print(time.time()-start_time) |
|
|
|
|
|
sample_rate = 24000 |
|
|
|
|
|
audio = (audio * 32767).astype(np.int16) |
|
|
|
|
|
audio = audio.flatten() |
|
|
|
if format.lower() == "wav": |
|
|
|
|
|
wav_io = io.BytesIO() |
|
|
|
|
|
write_wav(wav_io, sample_rate, audio) |
|
|
|
|
|
wav_io.seek(0) |
|
|
|
return Response(content=wav_io.read(), media_type="audio/wav") |
|
elif format.lower() == "opus": |
|
opus_data = audio_tensor_to_opus_bytes(torch.from_numpy(audio), sample_rate=sample_rate) |
|
return Response(content=opus_data, media_type="audio/opus") |
|
else: |
|
raise HTTPException(status_code=400, detail=f"Unsupported audio format: {format}") |
|
|
|
|
|
|
|
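
# Illustrative local test (assumes the server is running on the port set in
# the __main__ block below):
#
#   curl "http://localhost:7860/tts/full?text=Hello%20world&voice=af_heart&format=wav" -o out.wav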
@app.get("/", response_class=HTMLResponse) |
|
def index(): |
|
""" |
|
HTML demo page for Kokoro TTS. |
|
|
|
This page provides a simple UI to enter text, choose a voice and speed, |
|
and play synthesized audio from both the streaming and full endpoints. |
|
""" |
|
return """ |
|
<!DOCTYPE html> |
|
<html> |
|
<head> |
|
<title>Kokoro TTS Demo</title> |
|
</head> |
|
<body> |
|
<h1>Kokoro TTS Demo</h1> |
|
<textarea id="text" rows="4" cols="50" placeholder="Enter text here"></textarea><br> |
|
<label for="voice">Voice:</label> |
|
<input type="text" id="voice" value="af_heart"><br> |
|
<label for="speed">Speed:</label> |
|
<input type="number" step="0.1" id="speed" value="1.0"><br> |
|
<label for="format">Format:</label> |
|
<select id="format"> |
|
<option value="wav">WAV</option> |
|
<option value="opus" selected>Opus</option> |
|
</select><br><br> |
|
<button onclick="playStreaming()">Play Streaming TTS</button> |
|
<button onclick="playFull()">Play Full TTS</button> |
|
<br><br> |
|
<audio id="audio" controls autoplay></audio> |
|
<script> |
|
function playStreaming() { |
|
const text = document.getElementById('text').value; |
|
const voice = document.getElementById('voice').value; |
|
const speed = document.getElementById('speed').value; |
|
const format = document.getElementById('format').value; |
|
const audio = document.getElementById('audio'); |
|
// Set the audio element's source to the streaming endpoint. |
|
audio.src = `/tts/streaming?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`; |
|
audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus'; |
|
audio.play(); |
|
} |
|
function playFull() { |
|
const text = document.getElementById('text').value; |
|
const voice = document.getElementById('voice').value; |
|
const speed = document.getElementById('speed').value; |
|
const format = document.getElementById('format').value; |
|
const audio = document.getElementById('audio'); |
|
// Set the audio element's source to the full TTS endpoint. |
|
audio.src = `/tts/full?text=${encodeURIComponent(text)}&voice=${encodeURIComponent(voice)}&speed=${speed}&format=${format}`; |
|
audio.type = format === 'wav' ? 'audio/wav' : 'audio/opus'; |
|
audio.play(); |
|
} |
|
</script> |
|
</body> |
|
</html> |
|
""" |
|
|
|
|
|
|
|
|
|
|
|


if __name__ == "__main__":
    import uvicorn

    # The "app:app" import string assumes this file is saved as app.py.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)