import os
os.environ["HOME"] = "/root"
os.environ["HF_HOME"] = "/tmp/hf_cache"
import logging
import threading
import tempfile
import uuid
import torch
import numpy as np
import soundfile as sf
import torchaudio
import wave
import time
import re
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, BackgroundTasks
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from typing import Dict, Any, Optional, Tuple, List
from datetime import datetime, timedelta

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("talklas-api")

app = FastAPI(title="Talklas API")

# Mount a directory to serve audio files
AUDIO_DIR = "/tmp/audio_output"  # Use /tmp for temporary files
os.makedirs(AUDIO_DIR, exist_ok=True)
app.mount("/audio_output", StaticFiles(directory=AUDIO_DIR), name="audio_output")

# Global variables to track application state
models_loaded = False
loading_in_progress = False
loading_thread = None
model_status = {
    "stt": "not_loaded",
    "mt": "not_loaded",
    "tts": "not_loaded"
}
error_message = None
current_tts_language = "tgl"  # Track the current TTS language

# Model instances
stt_processor_whisper = None
stt_model_whisper = None
stt_processor_mms = None
stt_model_mms = None
mt_model = None
mt_tokenizer = None
tts_model = None
tts_tokenizer = None

# Define the valid languages and mappings
LANGUAGE_MAPPING = {
    "English": "eng",
    "Tagalog": "tgl",
    "Cebuano": "ceb",
    "Ilocano": "ilo",
    "Waray": "war",
    "Pangasinan": "pag"
}

# Mapping for Whisper language names
WHISPER_LANGUAGE_MAPPING = {
    "eng": "english",
    "tgl": "tagalog"
}

NLLB_LANGUAGE_CODES = {
    "eng": "eng_Latn",
    "tgl": "tgl_Latn",
    "ceb": "ceb_Latn",
    "ilo": "ilo_Latn",
    "war": "war_Latn",
    "pag": "pag_Latn"
}

# List of inappropriate words/phrases for content filtering
INAPPROPRIATE_WORDS = [
    # English inappropriate words
    "fuck", "shit", "bitch", "ass", "damn", "hell", "bastard", "cunt", "son of a bitch", "dick", "pussy", "motherfucker",
    # Philippine languages
    "agka baboy", "puta", "putang ina", "gago", "tanga", "hayop", "ulol", "lintik", "animal ka",
    "paki", "pakyu", "yawa", "bungol", "gingan", "yawa ka", "peste", "irig",
    "pakit", "ayat", "pua", "kayat mo ti agsardeng", "hinampak", "iring ka"
]

# Function to check for inappropriate content
def check_inappropriate_content(text: str) -> bool:
    """
    Check if the text contains inappropriate content.
    Returns True if inappropriate content is detected, False otherwise.
    """
    text_lower = text.lower()
    for word in INAPPROPRIATE_WORDS:
        pattern = r'\b' + re.escape(word) + r'\b'
        if re.search(pattern, text_lower):
            logger.warning(f"Inappropriate content detected: {word}")
            return True
    return False
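
# Illustrative behaviour of the filter (word-boundary matching means substrings
# inside longer words are not flagged):
#   check_inappropriate_content("what the hell")  -> True
#   check_inappropriate_content("hello world")    -> False ("hell" is not a whole word here)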

# Function to save PCM data as a WAV file
def save_pcm_to_wav(pcm_data: list, sample_rate: int, output_path: str):
    pcm_array = np.array(pcm_data, dtype=np.int16)
    with wave.open(output_path, 'wb') as wav_file:
        wav_file.setnchannels(1)  # mono
        wav_file.setsampwidth(2)  # 16-bit samples
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_array.tobytes())
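
# Minimal usage sketch for save_pcm_to_wav (values are hypothetical): write one
# second of a 440 Hz test tone as 16-bit mono PCM.
#   tone = (np.sin(2 * np.pi * 440 * np.arange(16000) / 16000) * 32767).astype(np.int16)
#   save_pcm_to_wav(tone.tolist(), 16000, "/tmp/audio_output/tone.wav")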

# Function to detect speech using an energy-based approach
def detect_speech(waveform: torch.Tensor, sample_rate: int, threshold: float = 0.01, min_speech_duration: float = 0.5) -> bool:
    # Note: min_speech_duration is currently unused; only overall RMS energy is checked
    waveform_np = waveform.numpy()
    if waveform_np.ndim > 1:
        waveform_np = waveform_np.mean(axis=0)  # downmix multi-channel audio to mono
    rms = np.sqrt(np.mean(waveform_np**2))
    logger.info(f"RMS energy: {rms}")
    if rms < threshold:
        logger.info("No speech detected: RMS energy below threshold")
        return False
    return True
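
# Illustrative behaviour (assuming the default 0.01 RMS threshold):
#   detect_speech(torch.zeros(1, 16000), 16000)       -> False (silence, RMS = 0)
#   detect_speech(torch.full((1, 16000), 0.1), 16000) -> True  (RMS = 0.1)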

# Function to clean up old audio files
def cleanup_old_audio_files():
    logger.info("Starting cleanup of old audio files...")
    expiration_time = datetime.now() - timedelta(minutes=10)
    for filename in os.listdir(AUDIO_DIR):
        file_path = os.path.join(AUDIO_DIR, filename)
        if os.path.isfile(file_path):
            file_mtime = datetime.fromtimestamp(os.path.getmtime(file_path))
            if file_mtime < expiration_time:
                try:
                    os.unlink(file_path)
                    logger.info(f"Deleted old audio file: {file_path}")
                except Exception as e:
                    logger.error(f"Error deleting file {file_path}: {str(e)}")

# Background task to periodically clean up audio files
def schedule_cleanup():
    while True:
        cleanup_old_audio_files()
        time.sleep(300)  # run every 5 minutes

# Function to load models in background
def load_models_task():
    global models_loaded, loading_in_progress, model_status, error_message
    global stt_processor_whisper, stt_model_whisper, stt_processor_mms, stt_model_mms
    global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
    try:
        loading_in_progress = True

        # Load STT models
        logger.info("Starting to load STT models...")
        from transformers import AutoProcessor, AutoModelForCTC, WhisperProcessor, WhisperForConditionalGeneration
        try:
            logger.info("Loading Whisper STT model...")
            model_status["stt"] = "loading"
            stt_processor_whisper = WhisperProcessor.from_pretrained("openai/whisper-tiny")
            stt_model_whisper = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
            device = "cuda" if torch.cuda.is_available() else "cpu"
            stt_model_whisper.to(device)
            logger.info("Whisper STT model loaded successfully")
            model_status["stt"] = "loaded_whisper"
        except Exception as e:
            logger.error(f"Failed to load Whisper STT model: {str(e)}")
            model_status["stt"] = "failed"
            error_message = f"Whisper STT model loading failed: {str(e)}"
            return

        try:
            logger.info("Loading MMS STT model...")
            stt_processor_mms = AutoProcessor.from_pretrained("facebook/mms-1b-all")
            stt_model_mms = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
            stt_model_mms.to(device)
            logger.info("MMS STT model loaded successfully")
            model_status["stt"] = "loaded_both" if model_status["stt"] == "loaded_whisper" else "loaded_mms"
        except Exception as e:
            logger.error(f"Failed to load MMS STT model: {str(e)}")
            if model_status["stt"] != "loaded_whisper":
                model_status["stt"] = "failed"
                error_message = f"MMS STT model loading failed: {str(e)}"
                return

        # Load MT model
        logger.info("Starting to load MT model...")
        from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
        try:
            logger.info("Loading NLLB-200-distilled-600M model...")
            model_status["mt"] = "loading"
            mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
            mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
            mt_model.to(device)
            logger.info("MT model loaded successfully")
            model_status["mt"] = "loaded"
        except Exception as e:
            logger.error(f"Failed to load MT model: {str(e)}")
            model_status["mt"] = "failed"
            error_message = f"MT model loading failed: {str(e)}"
            return

        # Load TTS model (default to Tagalog)
        logger.info("Starting to load TTS model...")
        from transformers import VitsModel, AutoTokenizer
        try:
            logger.info("Loading MMS-TTS model for Tagalog...")
            model_status["tts"] = "loading"
            tts_model = VitsModel.from_pretrained("facebook/mms-tts-tgl")
            tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-tgl")
            tts_model.to(device)
            logger.info("TTS model loaded successfully")
            model_status["tts"] = "loaded"
        except Exception as e:
            logger.error(f"Failed to load TTS model for Tagalog: {str(e)}")
            try:
                logger.info("Falling back to MMS-TTS English model...")
                tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
                tts_model.to(device)
                logger.info("Fallback TTS model loaded successfully")
                model_status["tts"] = "loaded (fallback)"
                current_tts_language = "eng"
            except Exception as e2:
                logger.error(f"Failed to load fallback TTS model: {str(e2)}")
                model_status["tts"] = "failed"
                error_message = f"TTS model loading failed: {str(e)} (fallback also failed: {str(e2)})"
                return

        models_loaded = True
        logger.info("Model loading completed successfully")
    except Exception as e:
        error_message = str(e)
        logger.error(f"Error in model loading task: {str(e)}")
    finally:
        loading_in_progress = False

# Start loading models in background
def start_model_loading():
    global loading_thread, loading_in_progress
    if not loading_in_progress and not models_loaded:
        loading_in_progress = True
        loading_thread = threading.Thread(target=load_models_task)
        loading_thread.daemon = True
        loading_thread.start()

# Start the background cleanup task
def start_cleanup_task():
    cleanup_thread = threading.Thread(target=schedule_cleanup)
    cleanup_thread.daemon = True
    cleanup_thread.start()

# Function to load or update TTS model for a specific language
def load_tts_model_for_language(target_code: str) -> bool:
    global tts_model, tts_tokenizer, current_tts_language, model_status
    if target_code not in LANGUAGE_MAPPING.values():
        logger.error(f"Invalid language code: {target_code}")
        return False
    if current_tts_language == target_code and model_status["tts"].startswith("loaded"):
        logger.info(f"TTS model for {target_code} is already loaded.")
        return True
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        logger.info(f"Loading MMS-TTS model for {target_code}...")
        from transformers import VitsModel, AutoTokenizer
        tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{target_code}")
        tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{target_code}")
        tts_model.to(device)
        current_tts_language = target_code
        logger.info(f"TTS model updated to {target_code}")
        model_status["tts"] = "loaded"
        return True
    except Exception as e:
        logger.error(f"Failed to load TTS model for {target_code}: {str(e)}")
        try:
            logger.info("Falling back to MMS-TTS English model...")
            tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
            tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
            tts_model.to(device)
            current_tts_language = "eng"
            logger.info("Fallback TTS model loaded successfully")
            model_status["tts"] = "loaded (fallback)"
            return True
        except Exception as e2:
            logger.error(f"Failed to load fallback TTS model: {str(e2)}")
            model_status["tts"] = "failed"
            return False

# Function to synthesize speech from text
def synthesize_speech(text: str, target_code: str) -> Tuple[Optional[str], Optional[str]]:
    global tts_model, tts_tokenizer
    request_id = str(uuid.uuid4())
    output_path = os.path.join(AUDIO_DIR, f"{request_id}.wav")
    if not load_tts_model_for_language(target_code):
        return None, "Failed to load TTS model for the target language"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        inputs = tts_tokenizer(text, return_tensors="pt").to(device)
        with torch.no_grad():
            output = tts_model(**inputs)
        speech = output.waveform.cpu().numpy().squeeze()
        speech = (speech * 32767).astype(np.int16)  # float [-1, 1] -> 16-bit PCM
        sample_rate = tts_model.config.sampling_rate
        save_pcm_to_wav(speech.tolist(), sample_rate, output_path)
        logger.info(f"Saved synthesized audio to {output_path}")
        return output_path, None
    except Exception as e:
        error_msg = f"Error during TTS conversion: {str(e)}"
        logger.error(error_msg)
        return None, error_msg
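
# Typical call (illustrative): returns (path, None) on success or (None, error)
# on failure.
#   wav_path, err = synthesize_speech("Magandang umaga", "tgl")
#   if wav_path:
#       url = f"/audio_output/{os.path.basename(wav_path)}"  # served via the StaticFiles mount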

# Start the background processes when the app starts
@app.on_event("startup")
async def startup_event():
    logger.info("Application starting up...")
    start_model_loading()
    start_cleanup_task()

@app.get("/")
async def root():
    logger.info("Root endpoint requested")
    return {"status": "healthy"}

@app.get("/health")
async def health_check():
    logger.info("Health check requested")
    return {
        "status": "healthy",
        "models_loaded": models_loaded,
        "loading_in_progress": loading_in_progress,
        "model_status": model_status,
        "error": error_message
    }
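
# Example client check while models load (same host the audio URLs below point at):
#   curl https://jerich-talklasapp2.hf.space/health
#   -> {"status": "healthy", "models_loaded": false, "loading_in_progress": true, ...}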

@app.post("/translate-text")
async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
    global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
    if not text:
        raise HTTPException(status_code=400, detail="No text provided")
    if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")
    logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")
    request_id = str(uuid.uuid4())
    source_code = LANGUAGE_MAPPING[source_lang]
    target_code = LANGUAGE_MAPPING[target_lang]
    translated_text = "Translation not available"
    if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
        try:
            source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
            target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
            mt_tokenizer.src_lang = source_nllb_code
            device = "cuda" if torch.cuda.is_available() else "cpu"
            inputs = mt_tokenizer(text, return_tensors="pt").to(device)
            with torch.no_grad():
                generated_tokens = mt_model.generate(
                    **inputs,
                    forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
                    max_length=448
                )
            translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
            logger.info(f"Translation completed: {translated_text}")
        except Exception as e:
            logger.error(f"Error during translation: {str(e)}")
            translated_text = f"Translation failed: {str(e)}"
    else:
        logger.warning("MT model not loaded, skipping translation")
    is_inappropriate = check_inappropriate_content(text) or check_inappropriate_content(translated_text)
    if is_inappropriate:
        logger.warning("Inappropriate content detected in translation request")
    output_audio_url = None
    if model_status["tts"].startswith("loaded"):
        if load_tts_model_for_language(target_code):
            try:
                output_path, error = synthesize_speech(translated_text, target_code)
                if output_path:
                    output_filename = os.path.basename(output_path)
                    output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
                    logger.info("TTS conversion completed")
            except Exception as e:
                logger.error(f"Error during TTS conversion: {str(e)}")
    return {
        "request_id": request_id,
        "status": "completed",
        "message": "Translation and TTS completed (or partially completed).",
        "source_text": text,
        "translated_text": translated_text,
        "output_audio": output_audio_url,
        "is_inappropriate": is_inappropriate
    }
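
# Example request (illustrative):
#   curl -X POST https://jerich-talklasapp2.hf.space/translate-text \
#        -F "text=Good morning" -F "source_lang=English" -F "target_lang=Tagalog"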

@app.post("/translate-audio")
async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
    global stt_processor_whisper, stt_model_whisper, stt_processor_mms, stt_model_mms
    global mt_model, mt_tokenizer, tts_model, tts_tokenizer, current_tts_language
    if not audio:
        raise HTTPException(status_code=400, detail="No audio file provided")
    if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")
    logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
    request_id = str(uuid.uuid4())
    source_code = LANGUAGE_MAPPING[source_lang]
    use_whisper = source_code in ["eng", "tgl"]  # Whisper covers only English and Tagalog here

    # Check if the appropriate STT model is loaded
    if use_whisper and (stt_processor_whisper is None or stt_model_whisper is None):
        logger.warning("Whisper STT model not loaded, returning placeholder response")
        return {
            "request_id": request_id,
            "status": "processing",
            "message": "Whisper STT model not loaded yet. Please try again later.",
            "source_text": "Transcription not available",
            "translated_text": "Translation not available",
            "output_audio": None,
            "is_inappropriate": False
        }
    elif not use_whisper and (stt_processor_mms is None or stt_model_mms is None):
        logger.warning("MMS STT model not loaded, returning placeholder response")
        return {
            "request_id": request_id,
            "status": "processing",
            "message": "MMS STT model not loaded yet. Please try again later.",
            "source_text": "Transcription not available",
            "translated_text": "Translation not available",
            "output_audio": None,
            "is_inappropriate": False
        }
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file.write(await audio.read())
        temp_path = temp_file.name
    transcription = "Transcription not available"
    translated_text = "Translation not available"
    output_audio_url = None
    is_inappropriate = False
    try:
        logger.info(f"Reading audio file: {temp_path}")
        waveform, sample_rate = torchaudio.load(temp_path)
        logger.info(f"Audio loaded: sample_rate={sample_rate}, waveform_shape={waveform.shape}")
        if sample_rate != 16000:
            logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
            resampler = torchaudio.transforms.Resample(sample_rate, 16000)
            waveform = resampler(waveform)
            sample_rate = 16000
        if not detect_speech(waveform, sample_rate):
            return {
                "request_id": request_id,
                "status": "failed",
                "message": "No speech detected in the audio.",
                "source_text": "No speech detected",
                "translated_text": "No translation available",
                "output_audio": None,
                "is_inappropriate": False
            }
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Using device: {device}")
        # Downmix to a mono 1-D array so stereo uploads are not misread as a batch
        audio_array = waveform.mean(dim=0).numpy()
        if use_whisper:
            logger.info("Using Whisper model for transcription")
            whisper_lang = WHISPER_LANGUAGE_MAPPING.get(source_code, "english")  # Default to English if not mapped
            inputs = stt_processor_whisper(audio_array, sampling_rate=16000, return_tensors="pt").to(device)
            with torch.no_grad():
                generated_ids = stt_model_whisper.generate(**inputs, language=whisper_lang)
            transcription = stt_processor_whisper.batch_decode(generated_ids, skip_special_tokens=True)[0]
        else:
            logger.info("Using MMS model for transcription")
            # MMS uses per-language adapters; switch to the source language so the
            # CTC head decodes with the right vocabulary (codes are ISO 639-3)
            stt_processor_mms.tokenizer.set_target_lang(source_code)
            stt_model_mms.load_adapter(source_code)
            inputs = stt_processor_mms(audio_array, sampling_rate=16000, return_tensors="pt").to(device)
            with torch.no_grad():
                logits = stt_model_mms(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = stt_processor_mms.batch_decode(predicted_ids)[0]
        logger.info(f"Transcription completed: {transcription}")
        target_code = LANGUAGE_MAPPING[target_lang]
        if model_status["mt"] == "loaded" and mt_model is not None and mt_tokenizer is not None:
            try:
                source_nllb_code = NLLB_LANGUAGE_CODES[source_code]
                target_nllb_code = NLLB_LANGUAGE_CODES[target_code]
                mt_tokenizer.src_lang = source_nllb_code
                inputs = mt_tokenizer(transcription, return_tensors="pt").to(device)
                with torch.no_grad():
                    generated_tokens = mt_model.generate(
                        **inputs,
                        forced_bos_token_id=mt_tokenizer.convert_tokens_to_ids(target_nllb_code),
                        max_length=448
                    )
                translated_text = mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
                logger.info(f"Translation completed: {translated_text}")
            except Exception as e:
                logger.error(f"Error during translation: {str(e)}")
                translated_text = f"Translation failed: {str(e)}"
        else:
            logger.warning("MT model not loaded, skipping translation")
        is_inappropriate = check_inappropriate_content(transcription) or check_inappropriate_content(translated_text)
        if is_inappropriate:
            logger.warning("Inappropriate content detected in audio transcription or translation")
        if load_tts_model_for_language(target_code):
            try:
                output_path, error = synthesize_speech(translated_text, target_code)
                if output_path:
                    output_filename = os.path.basename(output_path)
                    output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
                    logger.info("TTS conversion completed")
            except Exception as e:
                logger.error(f"Error during TTS conversion: {str(e)}")
        return {
            "request_id": request_id,
            "status": "completed",
            "message": "Transcription, translation, and TTS completed (or partially completed).",
            "source_text": transcription,
            "translated_text": translated_text,
            "output_audio": output_audio_url,
            "is_inappropriate": is_inappropriate
        }
    except Exception as e:
        logger.error(f"Error during processing: {str(e)}")
        return {
            "request_id": request_id,
            "status": "failed",
            "message": f"Processing failed: {str(e)}",
            "source_text": transcription,
            "translated_text": translated_text,
            "output_audio": output_audio_url,
            "is_inappropriate": is_inappropriate
        }
    finally:
        logger.info(f"Cleaning up temporary file: {temp_path}")
        os.unlink(temp_path)
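
# Example request (illustrative): upload a WAV clip; other container formats
# depend on the torchaudio backend available in the image.
#   curl -X POST https://jerich-talklasapp2.hf.space/translate-audio \
#        -F "audio=@sample.wav" -F "source_lang=Tagalog" -F "target_lang=English"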

@app.post("/text-to-speech")
async def text_to_speech(text: str = Form(...), target_lang: str = Form(...)):
    if not text:
        raise HTTPException(status_code=400, detail="No text provided")
    if target_lang not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")
    logger.info(f"Text-to-speech requested for text in {target_lang}")
    request_id = str(uuid.uuid4())
    target_code = LANGUAGE_MAPPING[target_lang]
    is_inappropriate = check_inappropriate_content(text)
    if is_inappropriate:
        logger.warning("Inappropriate content detected in text-to-speech request")
    output_audio_url = None
    if model_status["tts"].startswith("loaded") or load_tts_model_for_language(target_code):
        try:
            output_path, error = synthesize_speech(text, target_code)
            if output_path:
                output_filename = os.path.basename(output_path)
                output_audio_url = f"https://jerich-talklasapp2.hf.space/audio_output/{output_filename}"
                logger.info("TTS conversion completed")
            else:
                logger.error(f"TTS conversion failed: {error}")
        except Exception as e:
            logger.error(f"Error during TTS conversion: {str(e)}")
    else:
        logger.warning("TTS model not loaded and could not be loaded")
    return {
        "request_id": request_id,
        "status": "completed" if output_audio_url else "failed",
        "message": "TTS completed" if output_audio_url else "TTS failed",
        "text": text,
        "output_audio": output_audio_url,
        "is_inappropriate": is_inappropriate
    }
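
# Example request (illustrative):
#   curl -X POST https://jerich-talklasapp2.hf.space/text-to-speech \
#        -F "text=Magandang umaga" -F "target_lang=Tagalog"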

if __name__ == "__main__":
    import uvicorn
    logger.info("Starting Uvicorn server...")
    uvicorn.run(app, host="0.0.0.0", port=8000, workers=1)