import os
import torch
import gradio as gr
import numpy as np
import soundfile as sf
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    VitsModel,
    AutoProcessor,
    AutoModelForCTC,
    WhisperProcessor,
    WhisperForConditionalGeneration
)
from typing import Optional, Tuple, Dict, List
from flask import Flask, request, jsonify
from flask_cors import CORS
import base64
import io
import tempfile
class TalklasTranslator:
    """
    Speech-to-speech translation pipeline for Philippine languages.
    Uses MMS/Whisper for STT, NLLB for MT, and MMS for TTS.
    """

    LANGUAGE_MAPPING = {
        "English": "eng",
        "Tagalog": "tgl",
        "Cebuano": "ceb",
        "Ilocano": "ilo",
        "Waray": "war",
        "Pangasinan": "pag"
    }

    NLLB_LANGUAGE_CODES = {
        "eng": "eng_Latn",
        "tgl": "tgl_Latn",
        "ceb": "ceb_Latn",
        "ilo": "ilo_Latn",
        "war": "war_Latn",
        "pag": "pag_Latn"
    }
    def __init__(
        self,
        source_lang: str = "eng",
        target_lang: str = "tgl",
        device: Optional[str] = None
    ):
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.source_lang = source_lang
        self.target_lang = target_lang
        self.sample_rate = 16000

        print(f"Initializing Talklas Translator on {self.device}")

        # Initialize models
        self._initialize_stt_model()
        self._initialize_mt_model()
        self._initialize_tts_model()
    def _initialize_stt_model(self):
        """Initialize speech-to-text model with fallback to Whisper"""
        try:
            print("Loading STT model...")
            try:
                # Try loading MMS model first
                self.stt_processor = AutoProcessor.from_pretrained("facebook/mms-1b-all")
                self.stt_model = AutoModelForCTC.from_pretrained("facebook/mms-1b-all")

                # Set language if available
                if self.source_lang in self.stt_processor.tokenizer.vocab.keys():
                    self.stt_processor.tokenizer.set_target_lang(self.source_lang)
                    self.stt_model.load_adapter(self.source_lang)
                    print(f"Loaded MMS STT model for {self.source_lang}")
                else:
                    print(f"Language {self.source_lang} not in MMS, using default")
            except Exception as mms_error:
                print(f"MMS loading failed: {mms_error}")
                # Fallback to Whisper
                print("Loading Whisper as fallback...")
                self.stt_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
                self.stt_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
                print("Loaded Whisper STT model")

            self.stt_model.to(self.device)
        except Exception as e:
            print(f"STT model initialization failed: {e}")
            raise RuntimeError("Could not initialize STT model")
    def _initialize_mt_model(self):
        """Initialize machine translation model"""
        try:
            print("Loading NLLB Translation model...")
            self.mt_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
            self.mt_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
            self.mt_model.to(self.device)
            print("NLLB Translation model loaded")
        except Exception as e:
            print(f"MT model initialization failed: {e}")
            raise
    def _initialize_tts_model(self):
        """Initialize text-to-speech model"""
        try:
            print("Loading TTS model...")
            try:
                self.tts_model = VitsModel.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
                self.tts_tokenizer = AutoTokenizer.from_pretrained(f"facebook/mms-tts-{self.target_lang}")
                print(f"Loaded TTS model for {self.target_lang}")
            except Exception as tts_error:
                print(f"Target language TTS failed: {tts_error}")
                print("Falling back to English TTS")
                self.tts_model = VitsModel.from_pretrained("facebook/mms-tts-eng")
                self.tts_tokenizer = AutoTokenizer.from_pretrained("facebook/mms-tts-eng")

            self.tts_model.to(self.device)
        except Exception as e:
            print(f"TTS model initialization failed: {e}")
            raise
    def update_languages(self, source_lang: str, target_lang: str) -> str:
        """Update languages and reinitialize models if needed"""
        if source_lang == self.source_lang and target_lang == self.target_lang:
            return "Languages already set"

        self.source_lang = source_lang
        self.target_lang = target_lang

        # Only reinitialize the models that depend on language
        self._initialize_stt_model()
        self._initialize_tts_model()

        return f"Languages updated to {source_lang} → {target_lang}"
    def speech_to_text(self, audio_path: str) -> str:
        """Convert speech to text using the loaded STT model"""
        try:
            waveform, sample_rate = sf.read(audio_path)

            # Collapse stereo recordings to mono; the STT models expect a 1-D waveform
            if waveform.ndim > 1:
                waveform = waveform.mean(axis=1)

            if sample_rate != 16000:
                import librosa
                waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)

            inputs = self.stt_processor(
                waveform,
                sampling_rate=16000,
                return_tensors="pt"
            ).to(self.device)

            with torch.no_grad():
                if isinstance(self.stt_model, WhisperForConditionalGeneration):
                    # Whisper: autoregressive decoding
                    generated_ids = self.stt_model.generate(**inputs)
                    transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
                else:
                    # MMS (Wav2Vec2ForCTC): greedy CTC decoding
                    logits = self.stt_model(**inputs).logits
                    predicted_ids = torch.argmax(logits, dim=-1)
                    transcription = self.stt_processor.batch_decode(predicted_ids)[0]

            return transcription
        except Exception as e:
            print(f"Speech recognition failed: {e}")
            raise RuntimeError("Speech recognition failed")
    def translate_text(self, text: str) -> str:
        """Translate text using NLLB model"""
        try:
            source_code = self.NLLB_LANGUAGE_CODES[self.source_lang]
            target_code = self.NLLB_LANGUAGE_CODES[self.target_lang]

            self.mt_tokenizer.src_lang = source_code
            inputs = self.mt_tokenizer(text, return_tensors="pt").to(self.device)

            with torch.no_grad():
                generated_tokens = self.mt_model.generate(
                    **inputs,
                    forced_bos_token_id=self.mt_tokenizer.convert_tokens_to_ids(target_code),
                    max_length=448
                )

            return self.mt_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        except Exception as e:
            print(f"Translation failed: {e}")
            raise RuntimeError("Text translation failed")
    def text_to_speech(self, text: str) -> Tuple[int, np.ndarray]:
        """Convert text to speech"""
        try:
            inputs = self.tts_tokenizer(text, return_tensors="pt").to(self.device)

            with torch.no_grad():
                output = self.tts_model(**inputs)

            speech = output.waveform.cpu().numpy().squeeze()
            speech = (speech * 32767).astype(np.int16)

            return self.tts_model.config.sampling_rate, speech
        except Exception as e:
            print(f"Speech synthesis failed: {e}")
            raise RuntimeError("Speech synthesis failed")
    def translate_speech(self, audio_path: str) -> Dict:
        """Full speech-to-speech translation"""
        try:
            source_text = self.speech_to_text(audio_path)
            translated_text = self.translate_text(source_text)
            sample_rate, audio = self.text_to_speech(translated_text)

            return {
                "source_text": source_text,
                "translated_text": translated_text,
                "output_audio": (sample_rate, audio),
                "performance": "Translation successful"
            }
        except Exception as e:
            return {
                "source_text": "Error",
                "translated_text": "Error",
                "output_audio": (16000, np.zeros(1000, dtype=np.int16)),
                "performance": f"Error: {str(e)}"
            }
    def translate_text_only(self, text: str) -> Dict:
        """Text-to-speech translation"""
        try:
            translated_text = self.translate_text(text)
            sample_rate, audio = self.text_to_speech(translated_text)

            return {
                "source_text": text,
                "translated_text": translated_text,
                "output_audio": (sample_rate, audio),
                "performance": "Translation successful"
            }
        except Exception as e:
            return {
                "source_text": text,
                "translated_text": "Error",
                "output_audio": (16000, np.zeros(1000, dtype=np.int16)),
                "performance": f"Error: {str(e)}"
            }
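
# Minimal usage sketch for the pipeline above (not part of the app flow;
# "sample.wav" is a hypothetical 16 kHz recording):
#
#   translator = TalklasTranslator(source_lang="eng", target_lang="tgl")
#   result = translator.translate_speech("sample.wav")
#   print(result["source_text"], "->", result["translated_text"])
#   sr, audio = result["output_audio"]
#   sf.write("translated.wav", audio, sr)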
class TranslatorSingleton:
    """Lazily constructs and caches a single shared TalklasTranslator."""
    _instance = None

    @classmethod
    def get_instance(cls):
        if cls._instance is None:
            cls._instance = TalklasTranslator()
        return cls._instance
def process_audio(audio_path, source_lang, target_lang):
    """Process audio through the full translation pipeline"""
    # Validate input
    if not audio_path:
        return None, "No audio provided", "No translation available", "Please provide audio input"

    # Update languages
    source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
    target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
    translator = TranslatorSingleton.get_instance()
    status = translator.update_languages(source_code, target_code)

    # Process the audio
    results = translator.translate_speech(audio_path)

    return results["output_audio"], results["source_text"], results["translated_text"], results["performance"]
def process_text(text, source_lang, target_lang):
    """Process text through the translation pipeline"""
    # Validate input
    if not text:
        return None, "No text provided", "No translation available", "Please provide text input"

    # Update languages
    source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
    target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
    translator = TranslatorSingleton.get_instance()
    status = translator.update_languages(source_code, target_code)

    # Process the text
    results = translator.translate_text_only(text)

    return results["output_audio"], results["source_text"], results["translated_text"], results["performance"]
def create_gradio_interface():
    """Create and launch Gradio interface"""
    # Define language options
    languages = list(TalklasTranslator.LANGUAGE_MAPPING.keys())

    # Define the interface
    demo = gr.Blocks(title="Talklas - Speech & Text Translation")

    with demo:
        gr.Markdown("# Talklas: Speech-to-Speech Translation System")
        gr.Markdown("### Translate between Philippine Languages and English")

        with gr.Row():
            with gr.Column():
                source_lang = gr.Dropdown(
                    choices=languages,
                    value="English",
                    label="Source Language"
                )
                target_lang = gr.Dropdown(
                    choices=languages,
                    value="Tagalog",
                    label="Target Language"
                )
                language_status = gr.Textbox(label="Language Status")
                update_btn = gr.Button("Update Languages")

        with gr.Tabs():
            with gr.TabItem("Audio Input"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Audio Input")
                        audio_input = gr.Audio(
                            type="filepath",
                            label="Upload Audio File"
                        )
                        audio_translate_btn = gr.Button("Translate Audio", variant="primary")
                    with gr.Column():
                        gr.Markdown("### Output")
                        audio_output = gr.Audio(
                            label="Translated Speech",
                            type="numpy",
                            autoplay=True
                        )

            with gr.TabItem("Text Input"):
                with gr.Row():
                    with gr.Column():
                        gr.Markdown("### Text Input")
                        text_input = gr.Textbox(
                            label="Enter text to translate",
                            lines=3
                        )
                        text_translate_btn = gr.Button("Translate Text", variant="primary")
                    with gr.Column():
                        gr.Markdown("### Output")
                        text_output = gr.Audio(
                            label="Translated Speech",
                            type="numpy",
                            autoplay=True
                        )

        with gr.Row():
            with gr.Column():
                source_text = gr.Textbox(label="Source Text")
                translated_text = gr.Textbox(label="Translated Text")
                performance_info = gr.Textbox(label="Performance Metrics")
        # Set up events
        update_btn.click(
            lambda source_lang, target_lang: TranslatorSingleton.get_instance().update_languages(
                TalklasTranslator.LANGUAGE_MAPPING[source_lang],
                TalklasTranslator.LANGUAGE_MAPPING[target_lang]
            ),
            inputs=[source_lang, target_lang],
            outputs=[language_status]
        )

        # Audio translate button click
        audio_translate_btn.click(
            process_audio,
            inputs=[audio_input, source_lang, target_lang],
            outputs=[audio_output, source_text, translated_text, performance_info]
        ).then(
            None,
            None,
            None,
            js="""() => {
                const audioElements = document.querySelectorAll('audio');
                if (audioElements.length > 0) {
                    const lastAudio = audioElements[audioElements.length - 1];
                    lastAudio.play().catch(error => {
                        console.warn('Autoplay failed:', error);
                        alert('Audio may require user interaction to play');
                    });
                }
            }"""
        )

        # Text translate button click
        text_translate_btn.click(
            process_text,
            inputs=[text_input, source_lang, target_lang],
            outputs=[text_output, source_text, translated_text, performance_info]
        ).then(
            None,
            None,
            None,
            js="""() => {
                const audioElements = document.querySelectorAll('audio');
                if (audioElements.length > 0) {
                    const lastAudio = audioElements[audioElements.length - 1];
                    lastAudio.play().catch(error => {
                        console.warn('Autoplay failed:', error);
                        alert('Audio may require user interaction to play');
                    });
                }
            }"""
        )

    return demo
# Create Flask app
app = Flask(__name__)
CORS(app)  # This allows cross-origin requests

# Initialize the translator singleton
translator_instance = None

def get_translator():
    global translator_instance
    if translator_instance is None:
        translator_instance = TalklasTranslator()
    return translator_instance
# Route path is illustrative; adjust it to whatever the client expects.
@app.route("/api/translate-speech", methods=["POST"])
def api_translate_speech():
    """API endpoint for speech-to-speech translation"""
    try:
        # Check if required data is in the request
        if 'audio' not in request.files:
            return jsonify({
                "error": "No audio file provided"
            }), 400

        audio_file = request.files['audio']
        source_lang = request.form.get('source_lang', 'English')
        target_lang = request.form.get('target_lang', 'Tagalog')

        # Save temporary audio file
        with tempfile.NamedTemporaryFile(delete=False, suffix='.wav') as temp_audio:
            audio_file.save(temp_audio.name)
            temp_audio_path = temp_audio.name

        # Get translator and update languages
        translator = get_translator()
        source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
        target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
        translator.update_languages(source_code, target_code)

        # Process the audio
        results = translator.translate_speech(temp_audio_path)

        # Convert audio to base64 for transmission
        sample_rate, audio_data = results["output_audio"]
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, audio_data, sample_rate, format='WAV')
        audio_base64 = base64.b64encode(audio_bytes.getvalue()).decode('utf-8')

        # Clean up temporary file
        os.unlink(temp_audio_path)

        return jsonify({
            "source_text": results["source_text"],
            "translated_text": results["translated_text"],
            "audio_base64": audio_base64,
            "sample_rate": sample_rate,
            "status": "success"
        })
    except Exception as e:
        return jsonify({
            "error": str(e),
            "status": "error"
        }), 500
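
# Example client call (a sketch; the /api/translate-speech path above is an assumption
# and the requests library is not imported by this app):
#
#   import requests
#   with open("sample.wav", "rb") as f:
#       resp = requests.post(
#           "http://localhost:7860/api/translate-speech",
#           files={"audio": f},
#           data={"source_lang": "English", "target_lang": "Cebuano"},
#       )
#   print(resp.json()["translated_text"])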
# Route path is illustrative; adjust it to whatever the client expects.
@app.route("/api/translate-text", methods=["POST"])
def api_translate_text():
    """API endpoint for text-to-speech translation"""
    try:
        data = request.json
        if not data or 'text' not in data:
            return jsonify({
                "error": "No text provided"
            }), 400

        text = data['text']
        source_lang = data.get('source_lang', 'English')
        target_lang = data.get('target_lang', 'Tagalog')

        # Get translator and update languages
        translator = get_translator()
        source_code = TalklasTranslator.LANGUAGE_MAPPING[source_lang]
        target_code = TalklasTranslator.LANGUAGE_MAPPING[target_lang]
        translator.update_languages(source_code, target_code)

        # Process the text
        results = translator.translate_text_only(text)

        # Convert audio to base64 for transmission
        sample_rate, audio_data = results["output_audio"]
        audio_bytes = io.BytesIO()
        sf.write(audio_bytes, audio_data, sample_rate, format='WAV')
        audio_base64 = base64.b64encode(audio_bytes.getvalue()).decode('utf-8')

        return jsonify({
            "source_text": results["source_text"],
            "translated_text": results["translated_text"],
            "audio_base64": audio_base64,
            "sample_rate": sample_rate,
            "status": "success"
        })
    except Exception as e:
        return jsonify({
            "error": str(e),
            "status": "error"
        }), 500
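
# Example client call (a sketch; the /api/translate-text path above is an assumption
# and the requests library is not imported by this app):
#
#   import requests, base64
#   resp = requests.post(
#       "http://localhost:7860/api/translate-text",
#       json={"text": "Good morning", "source_lang": "English", "target_lang": "Tagalog"},
#   ).json()
#   with open("translated.wav", "wb") as f:
#       f.write(base64.b64decode(resp["audio_base64"]))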
# Route path is illustrative; adjust it to whatever the client expects.
@app.route("/api/languages", methods=["GET"])
def get_languages():
    """Return available languages"""
    return jsonify({
        "languages": list(TalklasTranslator.LANGUAGE_MAPPING.keys())
    })
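
# Expected response shape (sketch): {"languages": ["English", "Tagalog", "Cebuano", ...]}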
# Keep the Gradio interface (defined above in create_gradio_interface) for users who
# directly access the Hugging Face Space.
# Run both the API server and Gradio
if __name__ == "__main__":
    # Launch Gradio in a separate thread. Gradio also defaults to port 7860, so it is
    # started on 7861 here (an assumption; use any free port) to avoid colliding with Flask.
    import threading
    demo = create_gradio_interface()
    threading.Thread(
        target=demo.launch,
        kwargs={"share": True, "debug": False, "server_port": 7861}
    ).start()

    # Run the Flask server
    app.run(host='0.0.0.0', port=7860)