from flask import Flask, request, jsonify, send_from_directory, abort
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import librosa
import torch
import numpy as np
from onnxruntime import InferenceSession
import soundfile as sf
import os
import sys
import uuid
import logging
from flask_cors import CORS
import threading
import tempfile
from huggingface_hub import snapshot_download
from huggingface_hub.utils import RepositoryNotFoundError, HfHubHTTPError
import time
from tts_processor import preprocess_all
import hashlib

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
CORS(app, resources={r"/*": {"origins": "*"}})

# Global lock to ensure one method runs at a time
global_lock = threading.Lock()

# Repository ID and paths
kokoro_model_id = 'onnx-community/Kokoro-82M-v1.0-ONNX'
model_path = 'kokoro_model'
voice_name = 'am_adam'  # Voice style to use (adjust as needed)

# Directory to serve files from
SERVE_DIR = os.environ.get("SERVE_DIR", "./files")  # Default to './files' if not provided
os.makedirs(SERVE_DIR, exist_ok=True)


def validate_audio_file(file):
    if file.content_type not in ["audio/wav", "audio/x-wav", "audio/mpeg", "audio/mp3"]:
        raise ValueError("Unsupported file type")
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)  # Reset file pointer
    if file_size > 10 * 1024 * 1024:  # 10 MB limit
        raise ValueError("File is too large (max 10 MB)")


def validate_text_input(text):
    if not isinstance(text, str):
        raise ValueError("Text input must be a string")
    if len(text.strip()) == 0:
        raise ValueError("Text input cannot be empty")
    if len(text) > 1024:  # Limit to 1024 characters
        raise ValueError("Text input is too long (max 1024 characters)")


file_cache = {}


def is_cached(cached_file_path):
    """
    Check if a file exists in the cache.
    If the file is not in the cache, perform a disk check and update the cache.
    """
    if cached_file_path in file_cache:
        return file_cache[cached_file_path]  # Return cached result
    exists = os.path.exists(cached_file_path)  # Perform disk check
    file_cache[cached_file_path] = exists  # Update the cache
    return exists


def initialize_models():
    global sess, voice_style, processor, whisper_model
    max_retries = 5  # Maximum number of retries
    retry_delay = 2  # Initial delay in seconds (doubles after each retry)

    for attempt in range(max_retries):
        try:
            # Download the ONNX model if not already downloaded
            if not os.path.exists(model_path):
                logger.info(f"Attempt {attempt + 1} to download and load Kokoro model...")
                kokoro_dir = snapshot_download(kokoro_model_id, cache_dir=model_path)
                logger.info(f"Kokoro model directory: {kokoro_dir}")
            else:
                kokoro_dir = model_path
                logger.info(f"Using cached Kokoro model directory: {kokoro_dir}")

            # Locate the ONNX file
            onnx_path = None
            for root, _, files in os.walk(kokoro_dir):
                if 'model.onnx' in files:
                    onnx_path = os.path.join(root, 'model.onnx')
                    break
            if not onnx_path or not os.path.exists(onnx_path):
                raise FileNotFoundError(f"ONNX file not found under {kokoro_dir}")

            logger.info("Loading ONNX session...")
            sess = InferenceSession(onnx_path)
            logger.info(f"ONNX session loaded successfully from {onnx_path}")

            # Load the voice style vector
            voice_style_path = None
            for root, _, files in os.walk(kokoro_dir):
                if f'{voice_name}.bin' in files:
                    voice_style_path = os.path.join(root, f'{voice_name}.bin')
                    break
            if not voice_style_path or not os.path.exists(voice_style_path):
                raise FileNotFoundError(f"Voice style file not found for voice '{voice_name}' under {kokoro_dir}")

            logger.info("Loading voice style vector...")
            voice_style = np.fromfile(voice_style_path, dtype=np.float32).reshape(-1, 1, 256)
            logger.info(f"Voice style vector loaded successfully from {voice_style_path}")

            # Initialize Whisper model for S2T
            logger.info("Downloading and loading Whisper model...")
            processor = WhisperProcessor.from_pretrained("openai/whisper-base")
            whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-base")
            whisper_model.config.forced_decoder_ids = None
            logger.info("Whisper model loaded successfully")

            # If everything succeeds, break out of the retry loop
            break
        except (RepositoryNotFoundError, HfHubHTTPError, FileNotFoundError) as e:
            logger.error(f"Attempt {attempt + 1} failed: {str(e)}")
            if attempt == max_retries - 1:
                logger.error("Max retries reached. Failed to initialize models.")
                raise  # Re-raise the exception if max retries are reached
            time.sleep(retry_delay)
            retry_delay *= 2  # Exponential backoff


# Initialize models
initialize_models()


# Health check endpoint (route path assumed here; adjust to match your deployment)
@app.route("/health", methods=["GET"])
def health_check():
    try:
        return jsonify({"status": "healthy"}), 200
    except Exception as e:
        logger.error(f"Health check failed: {str(e)}")
        return jsonify({"status": "unhealthy"}), 500


@app.route("/generate_audio", methods=["POST"])
def generate_audio():
    """Text-to-Speech (T2S) Endpoint"""
    with global_lock:  # Acquire global lock so only one request runs at a time
        try:
            logger.debug("Received request to /generate_audio")
            data = request.json
            text = data['text']
            output_dir = data.get('output_dir')
            validate_text_input(text)
            logger.debug(f"Text: {text}")

            if not output_dir:
                raise ValueError("Output directory is required but not provided")
            # Ensure output_dir is an absolute path and valid
            if not os.path.isabs(output_dir):
                raise ValueError("Output directory must be an absolute path")
            if not os.path.exists(output_dir):
                raise ValueError(f"Output directory does not exist: {output_dir}")

            # Normalize the text and derive a unique hash for caching
            text = preprocess_all(text)
            logger.debug(f"Processed text: {text}")
            text_hash = hashlib.sha256(text.encode('utf-8')).hexdigest()
            hashed_file_name = f"{text_hash}.wav"
            cached_file_path = os.path.join(output_dir, hashed_file_name)
            logger.debug(f"Generated hash for processed text: {text_hash}")
            logger.debug(f"Output directory: {output_dir}")
            logger.debug(f"Cached file path: {cached_file_path}")

            # Return the cached file if it already exists
            if is_cached(cached_file_path):
                logger.info(f"Returning cached audio for text: {text}")
                return jsonify({"status": "success", "output_path": cached_file_path})

            # Tokenize text
            logger.debug("Tokenizing text...")
            from kokoro import phonemize, tokenize  # Imported lazily
            tokens = tokenize(phonemize(text, 'a'))
            logger.debug(f"Initial tokens: {tokens}")
            if len(tokens) > 510:
                logger.warning("Text too long; truncating to 510 tokens.")
                tokens = tokens[:510]
            tokens = [[0, *tokens, 0]]  # Add boundary padding tokens
            logger.debug(f"Final tokens: {tokens}")

            # Get style vector based on token length
            logger.debug("Fetching style vector...")
            ref_s = voice_style[len(tokens[0]) - 2]  # Shape: (1, 256)
            logger.debug(f"Style vector shape: {ref_s.shape}")

            # Run ONNX inference
            logger.debug("Running ONNX inference...")
            audio = sess.run(None, dict(
                input_ids=np.array(tokens, dtype=np.int64),
                style=ref_s,
                speed=np.ones(1, dtype=np.float32),
            ))[0]
            logger.debug(f"Audio generated with shape: {audio.shape}")

            # Fix audio data for saving
            audio = np.squeeze(audio)  # Remove extra dimension
            audio = audio.astype(np.float32)  # Ensure correct data type

            # Save audio
            logger.debug(f"Saving audio to {cached_file_path}...")
            sf.write(cached_file_path, audio, 24000)  # Save with 24 kHz sample rate
            file_cache[cached_file_path] = True  # Update existence cache so later requests reuse this file
            logger.info(f"Audio saved successfully to {cached_file_path}")
            return jsonify({"status": "success", "output_path": cached_file_path})
        except Exception as e:
            logger.error(f"Error generating audio: {str(e)}")
            return jsonify({"status": "error", "message": str(e)}), 500


@app.route("/transcribe_audio", methods=["POST"])
def transcribe_audio():
    """Speech-to-Text (S2T) Endpoint"""
    with global_lock:  # Acquire global lock so only one request runs at a time
        audio_path = None
        try:
            logger.debug("Received request to /transcribe_audio")
            file = request.files['file']
            validate_audio_file(file)

            # Save the upload under a unique name in the system temp directory
            unique_filename = f"{uuid.uuid4().hex}_{file.filename}"
            audio_path = os.path.join(tempfile.gettempdir(), unique_filename)
            file.save(audio_path)
            logger.debug(f"Audio file saved to {audio_path}")

            # Load and preprocess audio
            logger.debug("Processing audio for transcription...")
            audio_array, sampling_rate = librosa.load(audio_path, sr=16000)
            input_features = processor(
                audio_array,
                sampling_rate=sampling_rate,
                return_tensors="pt"
            ).input_features

            # Generate transcription
            logger.debug("Generating transcription...")
            predicted_ids = whisper_model.generate(input_features)
            transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
            logger.info(f"Transcription: {transcription}")
            return jsonify({"status": "success", "transcription": transcription})
        except Exception as e:
            logger.error(f"Error transcribing audio: {str(e)}")
            return jsonify({"status": "error", "message": str(e)}), 500
        finally:
            # Ensure the temporary file is removed
            if audio_path and os.path.exists(audio_path):
                os.remove(audio_path)
                logger.debug(f"Temporary file {audio_path} removed")


# Serve generated .wav files (route path assumed here; adjust to match your deployment)
@app.route("/files/<filename>", methods=["GET"])
def serve_wav_file(filename):
    """
    Serve a .wav file from the configured directory.
    Only serves files ending with '.wav'.
    """
    # Ensure only .wav files are allowed
    if not filename.lower().endswith('.wav'):
        abort(400, "Only .wav files are allowed.")

    # Check if the file exists in the directory
    file_path = os.path.join(SERVE_DIR, filename)
    logger.debug(f"Looking for file at: {file_path}")
    if not os.path.isfile(file_path):
        logger.error(f"File not found: {file_path}")
        abort(404, "File not found.")

    # Serve the file
    return send_from_directory(SERVE_DIR, filename)


# Error handlers
@app.errorhandler(400)
def bad_request(error):
    """Handle 400 errors."""
    return {"error": "Bad Request", "message": str(error)}, 400


@app.errorhandler(404)
def not_found(error):
    """Handle 404 errors."""
    return {"error": "Not Found", "message": str(error)}, 404


@app.errorhandler(500)
def internal_error(error):
    """Handle unexpected errors."""
    return {"error": "Internal Server Error", "message": "An unexpected error occurred."}, 500


if __name__ == "__main__":
    app.run(host="0.0.0.0", port=7860)
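
# Example client usage (sketch, assuming the route paths registered above and a
# server reachable at localhost:7860):
#
#   import requests
#
#   # Text-to-speech: returns the path of the generated (or cached) .wav file
#   r = requests.post(
#       "http://localhost:7860/generate_audio",
#       json={"text": "Hello world", "output_dir": "/absolute/path/to/files"},
#   )
#   print(r.json()["output_path"])
#
#   # Speech-to-text: returns the Whisper transcription of an uploaded file
#   with open("sample.wav", "rb") as f:
#       r = requests.post(
#           "http://localhost:7860/transcribe_audio",
#           files={"file": ("sample.wav", f, "audio/wav")},
#       )
#   print(r.json()["transcription"])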