# Source: Hugging Face Space file view — author Michael Hu,
# commit c72d839 ("add more logging"), 3.26 kB.
import functools
import logging
import os
import time
import uuid

import torch
from phonemizer.backend.espeak.wrapper import EspeakWrapper
from pydub import AudioSegment

from models import build_model
# Module-level logger shared by everything in this file.
logger = logging.getLogger(__name__)

# Hugging Face Spaces setup: the Kokoro checkout is expected under ./kokoro.
MODEL_DIR = "./kokoro"
os.makedirs(MODEL_DIR, exist_ok=True)

# Point phonemizer's espeak backend at libespeak-ng explicitly; this is the
# x86_64 multiarch location used in the Hugging Face environment.
_ESPEAK_NG_LIB = "/usr/lib/x86_64-linux-gnu/libespeak-ng.so.1"
EspeakWrapper.set_library(_ESPEAK_NG_LIB)
class TTSEngine:
    """Kokoro-82M text-to-speech engine that writes WAV files to disk.

    The Kokoro v0.19 model and the ``af_bella`` voice pack are loaded from
    ``MODEL_DIR`` once, at construction time, and reused for every
    ``generate_speech`` call. Output is 24 kHz, mono, 16-bit PCM WAV.
    """

    # Longest text (in characters) passed to the model; the rest is dropped.
    MAX_TEXT_CHARS = 500
    # Sample rate handed to pydub when wrapping the generated samples.
    SAMPLE_RATE = 24000
    # Directory for generated WAV files; created on demand.
    OUTPUT_DIR = "temp/outputs"

    def __init__(self):
        logger.info("Initializing TTS Engine")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.debug("Using device: %s", self.device)
        self._verify_model_files()
        model_path, voice_path = self._model_paths()
        logger.info("Loading Kokoro model")
        self.model = build_model(model_path, self.device)
        logger.info("Loading voice model")
        # NOTE(review): torch.load unpickles arbitrary objects and this file
        # comes from an externally cloned repository; consider passing
        # weights_only=True once the minimum supported torch version allows it.
        self.voice = torch.load(voice_path, map_location=self.device)
        logger.info("TTS engine initialized")

    @staticmethod
    def _model_paths():
        """Return ``(model_path, voice_path)`` for the files loaded from MODEL_DIR."""
        return (
            f"{MODEL_DIR}/kokoro-v0_19.pth",
            f"{MODEL_DIR}/voices/af_bella.pt",
        )

    def _verify_model_files(self):
        """Ensure required model files exist.

        Raises:
            FileNotFoundError: if any required file is absent; the message
                lists the missing paths and how to clone the repository.
        """
        missing = [f for f in self._model_paths() if not os.path.exists(f)]
        if missing:
            logger.error("Missing model files: %s", missing)
            raise FileNotFoundError(
                f"Missing model files: {missing}\n"
                "Add this to your Hugging Face Space settings:\n"
                "App setup -> Clone Kokoro repository: "
                "git clone https://huggingface.co/hexgrad/Kokoro-82M ./kokoro"
            )

    @staticmethod
    def _to_pcm16(audio) -> bytes:
        """Convert generated samples (torch tensor or numpy array) to int16 PCM bytes.

        Floating-point samples are assumed to be normalized to [-1, 1]
        (typical vocoder output; confirm for the Kokoro build in use) and
        are clipped to that range before scaling. Integer input is cast to
        int16 without scaling. Works for CPU and CUDA tensors alike.
        """
        samples = torch.as_tensor(audio).detach().cpu().flatten()
        if samples.is_floating_point():
            # Dumping float32 bytes while declaring sample_width=2 produces
            # noise, so scale to the int16 range first.
            samples = (samples.clamp(-1.0, 1.0) * 32767).to(torch.int16)
        elif samples.dtype != torch.int16:
            samples = samples.to(torch.int16)
        return samples.numpy().tobytes()

    @classmethod
    def _new_output_path(cls) -> str:
        """Return a fresh WAV path inside OUTPUT_DIR, creating the directory if needed."""
        os.makedirs(cls.OUTPUT_DIR, exist_ok=True)
        # The random suffix stops requests issued within the same second
        # from overwriting each other's files.
        return f"{cls.OUTPUT_DIR}/output_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"

    def generate_speech(self, text: str, language: str = "zh") -> str:
        """Synthesize *text* to a WAV file and return the file's path.

        Args:
            text: Input text. Characters beyond MAX_TEXT_CHARS are dropped.
            language: Currently unused. Synthesis always runs with
                ``lang='en-us'`` and the ``af_bella`` voice.
                NOTE(review): the ``"zh"`` default suggests Chinese input was
                intended; confirm the Kokoro build supports it before wiring
                this through.

        Returns:
            Path of the written 24 kHz mono 16-bit WAV file under OUTPUT_DIR.

        Raises:
            Exception: any failure during import, synthesis or export is
                logged with its traceback and re-raised unchanged.
        """
        logger.info("Generating speech for text length: %d", len(text))
        try:
            # Imported inside the try so an import failure is logged like any
            # other generation error.
            from kokoro import generate_full

            if len(text) > self.MAX_TEXT_CHARS:
                logger.warning("Truncating long text (%d characters)", len(text))
                # Plain cut: an appended marker would be read aloud by the TTS.
                text = text[: self.MAX_TEXT_CHARS]

            logger.debug("Starting audio generation")
            # NOTE(review): confirm the cloned kokoro module's generate_full
            # accepts these keyword arguments (lang code format, max_len).
            audio, _ = generate_full(
                self.model,
                text,
                self.voice,
                lang='en-us',
                max_len=200 if self.device == "cpu" else 500,
            )

            output_path = self._new_output_path()
            logger.debug("Saving audio to %s", output_path)
            segment = AudioSegment(
                self._to_pcm16(audio),
                frame_rate=self.SAMPLE_RATE,
                sample_width=2,
                channels=1,
            )
            # pydub's export returns the opened output file; close it so the
            # handle is released deterministically.
            segment.export(output_path, format="wav").close()
            logger.info("Audio generation complete: %s", output_path)
            return output_path
        except Exception as e:
            logger.exception("TTS generation failed: %s", e)
            raise
# Initialize TTS engine once and reuse it for every request.
@functools.lru_cache(maxsize=1)
def get_tts_engine():
    """Return the shared TTSEngine, constructing it on the first call.

    ``functools.lru_cache`` keeps the single instance for the lifetime of the
    process, so the model is loaded once. This replaces ``@st.cache_resource``,
    which referenced a name (``st``) that this module never imports and so
    raised NameError at import time.

    NOTE(review): unlike Streamlit's resource cache, lru_cache does not lock,
    so two threads racing on the very first call could each build an engine;
    only one is kept.
    """
    return TTSEngine()
def generate_speech(text: str, language: str = "zh") -> str:
    """Synthesize *text* with the shared engine and return the output file path.

    Public entry point: delegates to the TTSEngine instance returned by
    ``get_tts_engine()``, passing both arguments through unchanged.
    """
    engine = get_tts_engine()
    return engine.generate_speech(text, language)