import os
import torch
import numpy as np
import soundfile as sf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    VitsModel,
    AutoProcessor,
    AutoModelForCTC,
    WhisperProcessor,
    WhisperForConditionalGeneration
)
from typing import Optional, Tuple, Dict, List
import base64
import io

# Your existing TalklasTranslator class (unchanged)
class TalklasTranslator:
    LANGUAGE_MAPPING = {
        "English": "eng",
        "Tagalog": "tgl",
        "Cebuano": "ceb",
        "Ilocano": "ilo",
        "Waray": "war",
        "Pangasinan": "pag"
    }

    NLLB_LANGUAGE_CODES = {
        "eng": "eng_Latn",
        "tgl": "tgl_Latn",
        "ceb": "ceb_Latn",
        "ilo": "ilo_Latn",
        "war": "war_Latn",
        "pag": "pag_Latn"
    }

    def __init__(
        self,
        source_lang: str = "eng",
        target_lang: str = "tgl",
        device: Optional[str] = None
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.sample_rate = 16000
        print(f"Initializing Talklas Translator on {self.device}")
        self._initialize_stt_model()
        self._initialize_mt_model()
        self._initialize_tts_model()

    def _initialize_stt_model(self):
        try:
            print("Loading STT model...")
            # Prefer MMS (covers the Philippine languages); fall back to Whisper if it fails
            try:
                self.stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
                self.stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")
                if self.source_lang in self.stt_processor.tokenizer.vocab.keys():
                    self.stt_processor.tokenizer.set_target_lang(self.source_lang)
                    self.stt_model.load_adapter(self.source_lang)
                    print(f"Loaded MMS STT model for {self.source_lang}")
                else:
                    print(f"Language {self.source_lang} not in MMS, using default")
            except Exception as mms_error:
                print(f"MMS loading failed: {mms_error}")
                print("Loading Whisper as fallback...")
                self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
                self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
                print("Loaded Whisper STT model")
            self.stt_model.to(self.device)
        except Exception as e:
            print(f"STT model initialization failed: {e}")
            raise RuntimeError("Could not initialize STT model")

    def _initialize_mt_model(self):
        try:
            print("Loading NLLB Translation model...")
            self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
            self.mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
            self.mt_model.to(self.device)
            print("NLLB Translation model loaded")
        except Exception as e:
            print(f"MT model initialization failed: {e}")
            raise

    def _initialize_tts_model(self):
        try:
            print("Loading TTS model...")
            try:
                self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
                self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
                print(f"Loaded TTS model for {self.target_lang}")
            except Exception as tts_error:
                print(f"Target language TTS failed: {tts_error}")
                print("Falling back to English TTS")
                self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")
            self.tts_model.to(self.device)
        except Exception as e:
            print(f"TTS model initialization failed: {e}")
            raise

    def update_languages(self, source_lang: str, target_lang: str) -> str:
        if source_lang == self.source_lang and target_lang == self.target_lang:
            return "Languages already set"
        self.source_lang = source_lang
        self.target_lang = target_lang
        # Only STT and TTS are language-specific; the NLLB model covers all pairs
        self._initialize_stt_model()
        self._initialize_tts_model()
        return f"Languages updated to {source_lang} → {target_lang}"
    def speech_to_text(self, audio_path: str) -> str:
        try:
            waveform, sample_rate = sf.read(audio_path)
            # Collapse stereo to mono; the STT models expect a single channel
            if waveform.ndim > 1:
                waveform = waveform.mean(axis=1)
            if sample_rate != 16000:
                import librosa
                waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
            inputs = self.stt_processor(
                waveform,
                sampling_rate=16000,
                return_tensors="pt"
            ).to(self.device)
            with torch.no_grad():
                if isinstance(self.stt_model, WhisperForConditionalGeneration):
                    generated_ids = self.stt_model.generate(**inputs)
                    transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                else:
                    # CTC path (MMS): greedy argmax decoding over frame logits
                    logits = self.stt_model(**inputs).logits
                    predicted_ids = torch.argmax(logits, dim=-1)
                    transcription = self.stt_processor.batch_decode(predicted_ids)[0]
            return transcription
        except Exception as e:
            print(f"Speech recognition failed: {e}")
            raise RuntimeError("Speech recognition failed")
    def translate_text(self, text: str) -> str:
        try:
            source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
            target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]
            self.mt_tokenizer.src_lang = source_code
            inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                generated_tokens = self.mt_model.generate(
                    **inputs,
                    # Force the first generated token to the target language code
                    forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
                    max_length=448
                )
            return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        except Exception as e:
            print(f"Translation failed: {e}")
            raise RuntimeError("Text translation failed")

    def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
        try:
            inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device)
            with torch.no_grad():
                output = self.tts_model(**inputs)
            speech = output.waveform.cpu().numpy().squeeze()
            # VITS emits float samples in [-1, 1]; scale to 16-bit PCM
            speech = (speech * 32767).astype(np.int16)
            return self.tts_model.config.sampling_rate, speech
        except Exception as e:
            print(f"Speech synthesis failed: {e}")
            raise RuntimeError("Speech synthesis failed")
    def translate_speech(self, audio_path: str) -> Dict:
        try:
            source_text = self.speech_to_text(audio_path)
            translated_text = self.translate_text(source_text)
            sample_rate, audio = self.text_to_speech(translated_text)
            return {
                "source_text": source_text,
                "translated_text": translated_text,
                "output_audio": (sample_rate, audio),
                "performance": "Translation successful"
            }
        except Exception as e:
            return {
                "source_text": "Error",
                "translated_text": "Error",
                "output_audio": (16000, np.zeros(1000, dtype=np.int16)),
                "performance": f"Error: {str(e)}"
            }

    def translate_text_only(self, text: str) -> Dict:
        try:
            translated_text = self.translate_text(text)
            sample_rate, audio = self.text_to_speech(translated_text)
            return {
                "source_text": text,
                "translated_text": translated_text,
                "output_audio": (sample_rate, audio),
                "performance": "Translation successful"
            }
        except Exception as e:
            return {
                "source_text": text,
                "translated_text": "Error",
                "output_audio": (16000, np.zeros(1000, dtype=np.int16)),
                "performance": f"Error: {str(e)}"
            }
class TranslatorSingleton:
    """Lazily builds one shared TalklasTranslator so the models load only once."""
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = TalklasTranslator()
        return cls._instance
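
# The singleton above mutates shared model state via update_languages() on every
# request, which is unsafe if FastAPI serves requests concurrently. Below is a
# minimal thread-safe sketch; the name _LockedTranslatorSingleton and the lock
# strategy are illustrative assumptions, and it is not wired into the endpoints.
import threading

class _LockedTranslatorSingleton:
    """Hedged sketch: the same lazy singleton, serialized with a lock."""
    _instance = None
    _lock = threading.Lock()

    @classmethod
    def get_instance(cls) -> TalklasTranslator:
        # Double-checked locking keeps the common path cheap once initialized
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = TalklasTranslator()
        return cls._instance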
# FastAPI application
app = FastAPI(title="Talklas API", description="Speech-to-Speech Translation API")

class TranslationRequest(BaseModel):
    source_lang: str
    target_lang: str
    text: Optional[str] = None
# Route decorator restored; the "/translate-audio" path is an assumption, since
# the registration line was missing from the source.
@app.post("/translate-audio")
async def translate_audio(file: UploadFile = File(...), source_lang: str = "English", target_lang: str = "Tagalog"):
    try:
        # Validate languages
        if source_lang not in TalklasTranslator.LANGUAGE_MAPPING or target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
            raise HTTPException(status_code=400, detail="Invalid language selection")

        # Save uploaded audio file temporarily
        audio_path = f"temp_{file.filename}"
        with open(audio_path, "wb") as f:
            f.write(await file.read())

        try:
            # Update languages
            source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
            target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
            translator = TranslatorSingleton.get_instance()
            translator.update_languages(source_code, target_code)

            # Process the audio
            results = translator.translate_speech(audio_path)
        finally:
            # Clean up the temporary file even if processing fails
            os.remove(audio_path)

        # Convert audio to base64 for response
        sample_rate, audio = results["output_audio"]
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format="wav")
        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return JSONResponse(content={
            "source_text": results["source_text"],
            "translated_text": results["translated_text"],
            "audio_base64": audio_base64,
            "sample_rate": sample_rate,
            "performance": results["performance"]
        })
    except HTTPException:
        # Let deliberate 4xx responses pass through instead of becoming 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
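
# Minimal client sketch for the audio endpoint above ("sample.wav" and the
# /translate-audio path are illustrative; requests is an assumed client-side
# dependency, and this function is never called by the Space itself):
def _example_audio_client(path: str = "sample.wav") -> None:
    import requests
    with open(path, "rb") as f:
        resp = requests.post(
            "http://localhost:8000/translate-audio",
            files={"file": (path, f, "audio/wav")},
            params={"source_lang": "English", "target_lang": "Tagalog"},
        )
    resp.raise_for_status()
    body = resp.json()
    # Decode the base64 payload back into a playable WAV file
    with open("translated.wav", "wb") as out:
        out.write(base64.b64decode(body["audio_base64"]))
    print(body["source_text"], "->", body["translated_text"])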
@app.post("/translate-text")  # path assumed, as above
async def translate_text(request: TranslationRequest):
    try:
        # Validate input
        if not request.text:
            raise HTTPException(status_code=400, detail="Text input is required")
        if request.source_lang not in TalklasTranslator.LANGUAGE_MAPPING or request.target_lang not in TalklasTranslator.LANGUAGE_MAPPING:
            raise HTTPException(status_code=400, detail="Invalid language selection")

        # Update languages
        source_code = TalklasTranslator.LANGUAGE_MAPPING[request.source_lang]
        target_code = TalklasTranslator.LANGUAGE_MAPPING[request.target_lang]
        translator = TranslatorSingleton.get_instance()
        translator.update_languages(source_code, target_code)

        # Process the text
        results = translator.translate_text_only(request.text)

        # Convert audio to base64 for response
        sample_rate, audio = results["output_audio"]
        buffer = io.BytesIO()
        sf.write(buffer, audio, sample_rate, format="wav")
        audio_base64 = base64.b64encode(buffer.getvalue()).decode("utf-8")

        return JSONResponse(content={
            "source_text": results["source_text"],
            "translated_text": results["translated_text"],
            "audio_base64": audio_base64,
            "sample_rate": sample_rate,
            "performance": results["performance"]
        })
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Translation failed: {str(e)}")
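
# Companion sketch for the JSON endpoint (again assuming the /translate-text
# path above; requests is an assumed client-side dependency):
def _example_text_client() -> None:
    import requests
    resp = requests.post(
        "http://localhost:8000/translate-text",
        json={"source_lang": "English", "target_lang": "Cebuano", "text": "Good morning"},
    )
    resp.raise_for_status()
    # The response carries the same base64 audio field as the audio endpoint
    print(resp.json()["translated_text"])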
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)