TalklasApp / app.py
Jerich's picture
Expose the Hugging Face Code as an API
224fa8d verified
raw
history blame
12.3 kB
import base64
import io
import os
import tempfile
from typing import Dict, List, Optional, Tuple

import numpy as np
import soundfile as sf
import torch
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import (
    AutoModelForCTC,
    AutoModelForSeq2SeqLM,
    AutoProcessor,
    AutoTokenizer,
    VitsModel,
    WhisperForConditionalGeneration,
    WhisperProcessor,
)
# Core pipeline: speech-to-text -> NLLB text translation -> text-to-speech.
class TalklasTranslator:
    """Speech-to-speech translator chaining STT, NLLB translation and MMS TTS.

    Languages are identified internally by the three-letter codes that are the
    values of ``LANGUAGE_MAPPING`` (e.g. ``"eng"``, ``"tgl"``).
    """

    # Human-readable language name -> internal three-letter code.
    LANGUAGE_MAPPING = {
        "English": "eng",
        "Tagalog": "tgl",
        "Cebuano": "ceb",
        "Ilocano": "ilo",
        "Waray": "war",
        "Pangasinan": "pag"
    }
    # Internal code -> NLLB-200 language code used by the translation tokenizer.
    NLLB_LANGUAGE_CODES = {
        "eng": "eng_Latn",
        "tgl": "tgl_Latn",
        "ceb": "ceb_Latn",
        "ilo": "ilo_Latn",
        "war": "war_Latn",
        "pag": "pag_Latn"
    }

    def __init__(
        self,
        source_lang: str = "eng",
        target_lang: str = "tgl",
        device: Optional[str] = None
    ):
        """Load the STT, MT and TTS models.

        Args:
            source_lang: internal code of the input (spoken/written) language.
            target_lang: internal code of the output language.
            device: torch device string; defaults to "cuda" when available,
                otherwise "cpu".

        Raises:
            RuntimeError: if no STT model could be loaded.
            Exception: any error from loading the MT model or the fallback
                TTS model is re-raised unchanged.
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.sample_rate = 16000  # rate that STT input is resampled to
        print(f"Initializing Talklas Translator on {self.device}")
        self._initialize_stt_model()
        self._initialize_mt_model()
        self._initialize_tts_model()

    def _initialize_stt_model(self):
        """Load the speech-to-text model for ``self.source_lang``.

        Tries the multilingual MMS CTC model first and selects its
        ``source_lang`` adapter when the tokenizer has a vocabulary for it.
        If MMS cannot be loaded, falls back to Whisper-small.

        Raises:
            RuntimeError: if neither model could be loaded.
        """
        try:
            print("Loading STT model...")
            try:
                self.stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
                self.stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
                if self.source_lang in self.stt_processor.tokenizer.vocab:
                    self.stt_processor.tokenizer.set_target_lang(self.source_lang)
                    self.stt_model.load_adapter(self.source_lang)
                    print(f"Loaded MMS STT model for {self.source_lang}")
                else:
                    print(f"Language {self.source_lang} not in MMS, using default")
            except Exception as mms_error:
                print(f"MMS loading failed: {mms_error}")
                print("Loading Whisper as fallback...")
                self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
                self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
                print("Loaded Whisper STT model")
            self.stt_model.to(self.device)
        except Exception as e:
            print(f"STT model initialization failed: {e}")
            raise RuntimeError("Could not initialize STT model") from e

    def _initialize_mt_model(self):
        """Load the NLLB-200 (distilled 600M) translation model and tokenizer.

        Raises:
            Exception: any loading error is printed and re-raised unchanged.
        """
        try:
            print("Loading NLLB Translation model...")
            self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
            self.mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
            self.mt_model.to(self.device)
            print("NLLB Translation model loaded")
        except Exception as e:
            print(f"MT model initialization failed: {e}")
            raise

    def _initialize_tts_model(self):
        """Load the MMS VITS text-to-speech model for ``self.target_lang``.

        If no model exists for the target language, falls back to the English
        MMS TTS model. The text is still in the target language in that case.

        Raises:
            Exception: if even the English fallback fails to load.
        """
        try:
            print("Loading TTS model...")
            try:
                self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
                self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
                print(f"Loaded TTS model for {self.target_lang}")
            except Exception as tts_error:
                print(f"Target language TTS failed: {tts_error}")
                print("Falling back to English TTS")
                self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
            self.tts_model.to(self.device)
        except Exception as e:
            print(f"TTS model initialization failed: {e}")
            raise

    def update_languages(self, source_lang: str, target_lang: str) -> str:
        """Switch the language pair, reloading only the models that depend on it.

        The STT model is reloaded only when the source language changes. The
        TTS model is reloaded only when the target language changes. The NLLB
        model never needs reloading because ``translate_text`` selects both
        languages per call.

        Args:
            source_lang: internal code of the new input language.
            target_lang: internal code of the new output language.

        Returns:
            A short status message.
        """
        if source_lang == self.source_lang and target_lang == self.target_lang:
            return "Languages already set"
        source_changed = source_lang != self.source_lang
        target_changed = target_lang != self.target_lang
        self.source_lang = source_lang
        self.target_lang = target_lang
        if source_changed:
            self._initialize_stt_model()
        if target_changed:
            self._initialize_tts_model()
        return f"Languages updated to {source_lang} -> {target_lang}"

    @staticmethod
    def _to_mono(waveform: np.ndarray) -> np.ndarray:
        """Average a (frames, channels) array down to one channel; 1-D input is returned as-is."""
        waveform = np.asarray(waveform)
        if waveform.ndim > 1:
            # soundfile returns multi-channel audio as (frames, channels).
            return waveform.mean(axis=1)
        return waveform

    @staticmethod
    def _float_to_pcm16(speech: np.ndarray) -> np.ndarray:
        """Convert float samples (nominally in [-1, 1]) to int16 PCM.

        Values are clipped first. Samples beyond +/-1 would otherwise overflow
        int16 during the cast, and the out-of-range result is platform-dependent
        (typically wrapping to the opposite sign).
        """
        return (np.clip(speech, -1.0, 1.0) * 32767).astype(np.int16)

    def speech_to_text(self, audio_path: str) -> str:
        """Transcribe the audio file at ``audio_path`` in ``self.source_lang``.

        Multi-channel audio is downmixed to mono. Audio that is not at
        ``self.sample_rate`` is resampled with librosa.

        Raises:
            RuntimeError: wrapping any read, resample or inference failure.
        """
        try:
            waveform, file_rate = sf.read(audio_path)
            waveform = self._to_mono(waveform)
            if file_rate != self.sample_rate:
                import librosa  # lazy: only needed for input at another rate
                waveform = librosa.resample(waveform, orig_sr=file_rate, target_sr=self.sample_rate)
            inputs = self.stt_processor(
                waveform,
                sampling_rate=self.sample_rate,
                return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                if isinstance(self.stt_model, WhisperForConditionalGeneration):
                    generated_ids = self.stt_model.generate(**inputs)
                    transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                else:
                    # CTC model: greedy decoding over per-frame logits.
                    logits = self.stt_model(**inputs).logits
                    predicted_ids = torch.argmax(logits, dim=-1)
                    transcription = self.stt_processor.batch_decode(predicted_ids)[0]
            return transcription
        except Exception as e:
            print(f"Speech recognition failed: {e}")
            raise RuntimeError("Speech recognition failed") from e

    def translate_text(self, text: str) -> str:
        """Translate ``text`` from ``self.source_lang`` to ``self.target_lang`` with NLLB.

        Raises:
            RuntimeError: wrapping any failure, including an unknown language code.
        """
        try:
            source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
            target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
            self.mt_tokenizer.src_lang = source_code
            inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                generated_tokens = self.mt_model.generate(
                    **inputs,
                    # Force the first generated token to the target-language code.
                    forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
                    max_length=448
                )
            return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        except Exception as e:
            print(f"Translation failed: {e}")
            raise RuntimeError("Text translation failed") from e

    def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
        """Synthesize ``text`` with the loaded TTS model.

        Returns:
            ``(sampling_rate, samples)``, where ``samples`` is an int16 array.

        Raises:
            RuntimeError: wrapping any tokenization or synthesis failure.
        """
        try:
            inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.tts_model(**inputs)
            speech = output.waveform.cpu().numpy().squeeze()
            return self.tts_model.config.sampling_rate, self._float_to_pcm16(speech)
        except Exception as e:
            print(f"Speech synthesis failed: {e}")
            raise RuntimeError("Speech synthesis failed") from e

    def translate_speech(self, audio_path: str) -> Dict:
        """Run STT -> translation -> TTS on an audio file.

        Any ``Exception`` is caught. The result then has "Error" texts,
        1000 samples of int16 silence at 16 kHz, and the message in
        ``performance``.
        """
        try:
            source_text = self.speech_to_text(audio_path)
            translated_text = self.translate_text(source_text)
            sample_rate, audio = self.text_to_speech(translated_text)
            return {
                "source_text": source_text,
                "translated_text": translated_text,
                "output_audio": (sample_rate, audio),
                "performance": "Translation successful"
            }
        except Exception as e:
            return {
                "source_text": "Error",
                "translated_text": "Error",
                "output_audio": (16000, np.zeros(1000, dtype=np.int16)),
                "performance": f"Error: {str(e)}"
            }

    def translate_text_only(self, text: str) -> Dict:
        """Translate ``text`` and synthesize the translation.

        Any ``Exception`` is caught. The result then keeps the input text,
        sets the translation to "Error", uses 1000 samples of int16 silence at
        16 kHz, and puts the message in ``performance``.
        """
        try:
            translated_text = self.translate_text(text)
            sample_rate, audio = self.text_to_speech(translated_text)
            return {
                "source_text": text,
                "translated_text": translated_text,
                "output_audio": (sample_rate, audio),
                "performance": "Translation successful"
            }
        except Exception as e:
            return {
                "source_text": text,
                "translated_text": "Error",
                "output_audio": (16000, np.zeros(1000, dtype=np.int16)),
                "performance": f"Error: {str(e)}"
            }
class TranslatorSingleton:
    """Lazily create and cache a single shared TalklasTranslator."""

    _instance = None  # cached translator; stays None until first requested

    @classmethod
    def get_instance(cls):
        """Return the shared translator, building it with default languages on first use."""
        if cls._instance is not None:
            return cls._instance
        cls._instance = TalklasTranslator()
        return cls._instance
# FastAPI application object; the route handlers in this module register on it.
app = FastAPI(
    description="Speech-to-Speech Translation API",
    title="Talklas API",
)
class TranslationRequest(BaseModel):
    """JSON body accepted by the POST /translate/text endpoint."""
    # Human-readable language names, e.g. "English"; the endpoint validates
    # them against TalklasTranslator.LANGUAGE_MAPPING.
    source_lang: str
    target_lang: str
    # Text to translate. Optional in the schema, but the endpoint rejects a
    # missing or empty value with HTTP 400.
    text: Optional[str] = None
@app.post("/translate/audio")
async def translate_audio(file: UploadFile = File(...), source_lang: str = "English", target_lang: str = "Tagalog"):
    """Translate an uploaded speech recording and return text plus synthesized audio.

    Args:
        file: uploaded audio. It is written to a temporary file and decoded
            with soundfile.
        source_lang: human-readable name of the spoken language (a key of
            ``TalklasTranslator.LANGUAGE_MAPPING``).
        target_lang: human-readable name of the output language.

    Returns:
        JSON with ``source_text``, ``translated_text``, ``audio_base64``
        (base64-encoded WAV), ``sample_rate`` and ``performance``. Pipeline
        errors caught by ``translate_speech`` are reported in ``performance``
        with "Error" placeholders rather than as an HTTP error.

    Raises:
        HTTPException: 400 for an unsupported language; 500 for any other failure.
    """
    # Validate outside the generic handler below so the 400 is not turned into a 500.
    if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selection")
    audio_path = None
    try:
        # Never build a path from the client-supplied filename: that allows
        # path traversal, collides between concurrent uploads, and fails when
        # the filename is None. Only its extension is kept.
        suffix = os.path.splitext(os.path.basename(file.filename or ""))[1]
        with tempfile.NamedTemporaryFile(prefix="temp_", suffix=suffix, delete=False) as tmp:
            audio_path = tmp.name
            tmp.write(await file.read())
        source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
        target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
        translator = TranslatorSingleton.get_instance()
        translator.update_languages(source_code, target_code)
        results = translator.translate_speech(audio_path)
        # Encode the synthesized audio as an in-memory WAV, then base64 for JSON.
        sample_rate, audio = results["output_audio"]
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format="wav")
        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return JSONResponse(content={
            "source_text": results["source_text"],
            "translated_text": results["translated_text"],
            "audio_base64": audio_base64,
            "sample_rate": sample_rate,
            "performance": results["performance"]
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}") from e
    finally:
        # Remove the temporary upload even when the pipeline raised.
        if audio_path is not None and os.path.exists(audio_path):
            os.remove(audio_path)
@app.post("/translate/text")
async def translate_text(request: TranslationRequest):
    """Translate text and synthesize speech for the translation.

    Args:
        request: body with ``text`` and human-readable ``source_lang`` /
            ``target_lang`` names (keys of ``TalklasTranslator.LANGUAGE_MAPPING``).

    Returns:
        JSON with ``source_text``, ``translated_text``, ``audio_base64``
        (base64-encoded WAV), ``sample_rate`` and ``performance``. Errors
        caught by ``translate_text_only`` are reported in ``performance`` with
        an "Error" translation rather than as an HTTP error.

    Raises:
        HTTPException: 400 for missing or empty text or an unsupported
            language; 500 for any other failure.
    """
    # Validate outside the generic handler below so these 400s are not
    # turned into 500 responses.
    if not request.text:
        raise HTTPException(status_code=400, detail="Text input is required")
    if (
        request.source_lang not in TalklasTranslator.LANGUAGE_MAPPING
        or request.target_lang not in TalklasTranslator.LANGUAGE_MAPPING
    ):
        raise HTTPException(status_code=400, detail="Invalid language selection")
    try:
        source_code = TalklasTranslator.LANGUAGE_MAPPING[request.source_lang]
        target_code = TalklasTranslator.LANGUAGE_MAPPING[request.target_lang]
        translator = TranslatorSingleton.get_instance()
        translator.update_languages(source_code, target_code)
        results = translator.translate_text_only(request.text)
        # Encode the synthesized audio as an in-memory WAV, then base64 for JSON.
        sample_rate, audio = results["output_audio"]
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format="wav")
        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")
        return JSONResponse(content={
            "source_text": results["source_text"],
            "translated_text": results["translated_text"],
            "audio_base64": audio_base64,
            "sample_rate": sample_rate,
            "performance": results["performance"]
        })
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}") from e
if __name__ == "__main__":
    # Serve the API with uvicorn when this file is executed directly.
    import uvicorn

    bind_host, bind_port = "0.0.0.0", 8000
    uvicorn.run(app, host=bind_host, port=bind_port)