Update app.py
Browse files
app.py
CHANGED
@@ -1,13 +1,225 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
librosa
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
datasets
|
11 |
-
|
12 |
-
|
13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import librosa
|
5 |
+
import soundfile as sf
|
6 |
+
import tempfile
|
7 |
+
import os
|
8 |
+
|
9 |
+
from transformers import pipeline, VitsModel, AutoTokenizer
|
10 |
+
from datasets import load_dataset
|
11 |
+
|
12 |
+
# For MeloTTS (Chinese and Japanese)
|
13 |
+
try:
|
14 |
+
from melo.api import TTS as MeloTTS
|
15 |
+
except ImportError:
|
16 |
+
raise ImportError("Please install the MeloTTS package (e.g., pip install myshell-ai/MeloTTS-Chinese)")
|
17 |
+
|
18 |
+
# ------------------------------------------------------
|
19 |
+
# 1. ASR Pipeline (English) using Wav2Vec2
|
20 |
+
# ------------------------------------------------------
|
21 |
+
asr = pipeline(
|
22 |
+
"automatic-speech-recognition",
|
23 |
+
model="facebook/wav2vec2-base-960h"
|
24 |
+
)
|
25 |
+
|
26 |
+
# ------------------------------------------------------
|
27 |
+
# 2. Translation Models (8 languages)
|
28 |
+
# ------------------------------------------------------
|
29 |
+
translation_models = {
|
30 |
+
"Spanish": "Helsinki-NLP/opus-mt-en-es",
|
31 |
+
"Vietnamese": "Helsinki-NLP/opus-mt-en-vi",
|
32 |
+
"Indonesian": "Helsinki-NLP/opus-mt-en-id",
|
33 |
+
"Turkish": "Helsinki-NLP/opus-mt-en-trk",
|
34 |
+
"Portuguese": "Helsinki-NLP/opus-mt-tc-big-en-pt",
|
35 |
+
"Korean": "Helsinki-NLP/opus-mt-tc-big-en-ko",
|
36 |
+
"Chinese": "Helsinki-NLP/opus-mt-en-zh",
|
37 |
+
"Japanese": "Helsinki-NLP/opus-mt-en-jap"
|
38 |
+
}
|
39 |
+
|
40 |
+
translation_tasks = {
|
41 |
+
"Spanish": "translation_en_to_es",
|
42 |
+
"Vietnamese": "translation_en_to_vi",
|
43 |
+
"Indonesian": "translation_en_to_id",
|
44 |
+
"Turkish": "translation_en_to_tr",
|
45 |
+
"Portuguese": "translation_en_to_pt",
|
46 |
+
"Korean": "translation_en_to-ko",
|
47 |
+
"Chinese": "translation_en_to_zh",
|
48 |
+
"Japanese": "translation_en_to_ja"
|
49 |
+
}
|
50 |
+
|
51 |
+
# ------------------------------------------------------
|
52 |
+
# 3. TTS Configuration
|
53 |
+
# - MMS TTS (VITS) for: Spanish, Vietnamese, Indonesian, Turkish, Portuguese, Korean
|
54 |
+
# - MeloTTS for: Chinese and Japanese
|
55 |
+
# ------------------------------------------------------
|
56 |
+
tts_config = {
|
57 |
+
"Spanish": {"model_id": "facebook/mms-tts-spa", "architecture": "vits", "type": "mms"},
|
58 |
+
"Vietnamese": {"model_id": "facebook/mms-tts-vie", "architecture": "vits", "type": "mms"},
|
59 |
+
"Indonesian": {"model_id": "facebook/mms-tts-ind", "architecture": "vits", "type": "mms"},
|
60 |
+
"Turkish": {"model_id": "facebook/mms-tts-tur", "architecture": "vits", "type": "mms"},
|
61 |
+
"Portuguese": {"model_id": "facebook/mms-tts-por", "architecture": "vits", "type": "mms"},
|
62 |
+
"Korean": {"model_id": "facebook/mms-tts-kor", "architecture": "vits", "type": "mms"},
|
63 |
+
"Chinese": {"type": "melo"},
|
64 |
+
"Japanese": {"type": "melo"}
|
65 |
+
}
|
66 |
+
|
67 |
+
# ------------------------------------------------------
|
68 |
+
# 4. Global Caches for Translators and TTS Models
|
69 |
+
# ------------------------------------------------------
|
70 |
+
translator_cache = {}
|
71 |
+
mms_tts_cache = {} # For MMS (VITS-based) TTS models
|
72 |
+
melo_tts_cache = {} # For MeloTTS models (Chinese/Japanese)
|
73 |
+
|
74 |
+
# ------------------------------------------------------
|
75 |
+
# 5. Translator Helper
|
76 |
+
# ------------------------------------------------------
|
77 |
+
def get_translator(lang):
|
78 |
+
if lang in translator_cache:
|
79 |
+
return translator_cache[lang]
|
80 |
+
model_name = translation_models[lang]
|
81 |
+
task_name = translation_tasks[lang]
|
82 |
+
translator = pipeline(task_name, model=model_name)
|
83 |
+
translator_cache[lang] = translator
|
84 |
+
return translator
|
85 |
+
|
86 |
+
# ------------------------------------------------------
|
87 |
+
# 6. MMS TTS (VITS) Helper for languages using MMS TTS
|
88 |
+
# ------------------------------------------------------
|
89 |
+
def load_mms_tts(lang):
|
90 |
+
if lang in mms_tts_cache:
|
91 |
+
return mms_tts_cache[lang]
|
92 |
+
config = tts_config[lang]
|
93 |
+
try:
|
94 |
+
model = VitsModel.from_pretrained(config["model_id"])
|
95 |
+
tokenizer = AutoTokenizer.from_pretrained(config["model_id"])
|
96 |
+
mms_tts_cache[lang] = (model, tokenizer)
|
97 |
+
except Exception as e:
|
98 |
+
raise RuntimeError(f"Failed to load MMS TTS model for {lang} ({config['model_id']}): {e}")
|
99 |
+
return mms_tts_cache[lang]
|
100 |
+
|
101 |
+
def run_mms_tts(text, lang):
|
102 |
+
model, tokenizer = load_mms_tts(lang)
|
103 |
+
inputs = tokenizer(text, return_tensors="pt")
|
104 |
+
with torch.no_grad():
|
105 |
+
output = model(**inputs)
|
106 |
+
if not hasattr(output, "waveform"):
|
107 |
+
raise RuntimeError(f"MMS TTS model output for {lang} does not contain 'waveform'.")
|
108 |
+
waveform = output.waveform.squeeze().cpu().numpy()
|
109 |
+
sample_rate = 16000
|
110 |
+
return sample_rate, waveform
|
111 |
+
|
112 |
+
# ------------------------------------------------------
|
113 |
+
# 7. MeloTTS Helper for Chinese and Japanese
|
114 |
+
# ------------------------------------------------------
|
115 |
+
def run_melo_tts(text, lang):
|
116 |
+
"""
|
117 |
+
Uses the myshell-ai MeloTTS model.
|
118 |
+
For Chinese, use language parameter 'ZH'; for Japanese, use 'JP'.
|
119 |
+
"""
|
120 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
121 |
+
lang_param = 'ZH' if lang == "Chinese" else 'JP'
|
122 |
+
if lang not in melo_tts_cache:
|
123 |
+
try:
|
124 |
+
model = MeloTTS(language=lang_param, device=device)
|
125 |
+
melo_tts_cache[lang] = model
|
126 |
+
except Exception as e:
|
127 |
+
raise RuntimeError(f"Failed to load MeloTTS model for {lang}: {e}")
|
128 |
+
else:
|
129 |
+
model = melo_tts_cache[lang]
|
130 |
+
speaker_ids = model.hps.data.spk2id
|
131 |
+
# Assume the speaker key is the same as lang_param
|
132 |
+
speaker_key = lang_param
|
133 |
+
speed = 1.0
|
134 |
+
with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
|
135 |
+
tmp_name = tmp.name
|
136 |
+
try:
|
137 |
+
model.tts_to_file(text, speaker_ids[speaker_key], tmp_name, speed=speed)
|
138 |
+
data, sr = sf.read(tmp_name)
|
139 |
+
finally:
|
140 |
+
if os.path.exists(tmp_name):
|
141 |
+
os.remove(tmp_name)
|
142 |
+
return sr, data
|
143 |
+
|
144 |
+
# ------------------------------------------------------
|
145 |
+
# 8. Main Prediction Function
|
146 |
+
# ------------------------------------------------------
|
147 |
+
def predict(audio, text, target_language):
|
148 |
+
"""
|
149 |
+
1. Obtain English text (via ASR if audio provided, else text).
|
150 |
+
2. Translate the English text to target_language.
|
151 |
+
3. Generate TTS audio using either MMS TTS (VITS) or MeloTTS.
|
152 |
+
"""
|
153 |
+
# Step 1: Get English text.
|
154 |
+
if text.strip():
|
155 |
+
english_text = text.strip()
|
156 |
+
elif audio is not None:
|
157 |
+
sample_rate, audio_data = audio
|
158 |
+
if audio_data.dtype not in [np.float32, np.float64]:
|
159 |
+
audio_data = audio_data.astype(np.float32)
|
160 |
+
if len(audio_data.shape) > 1 and audio_data.shape[1] > 1:
|
161 |
+
audio_data = np.mean(audio_data, axis=1)
|
162 |
+
if sample_rate != 16000:
|
163 |
+
audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
|
164 |
+
asr_input = {"array": audio_data, "sampling_rate": 16000}
|
165 |
+
asr_result = asr(asr_input)
|
166 |
+
english_text = asr_result["text"]
|
167 |
+
else:
|
168 |
+
return "No input provided.", "", None
|
169 |
+
|
170 |
+
# Step 2: Translate.
|
171 |
+
translator = get_translator(target_language)
|
172 |
+
try:
|
173 |
+
translation_result = translator(english_text)
|
174 |
+
translated_text = translation_result[0]["translation_text"]
|
175 |
+
except Exception as e:
|
176 |
+
return english_text, f"Translation error: {e}", None
|
177 |
+
|
178 |
+
# Step 3: TTS.
|
179 |
+
try:
|
180 |
+
tts_type = tts_config[target_language]["type"]
|
181 |
+
if tts_type == "mms":
|
182 |
+
sr, waveform = run_mms_tts(translated_text, target_language)
|
183 |
+
elif tts_type == "melo":
|
184 |
+
sr, waveform = run_melo_tts(translated_text, target_language)
|
185 |
+
else:
|
186 |
+
raise RuntimeError("Unknown TTS type for target language.")
|
187 |
+
except Exception as e:
|
188 |
+
return english_text, translated_text, f"TTS error: {e}"
|
189 |
+
|
190 |
+
return english_text, translated_text, (sr, waveform)
|
191 |
+
|
192 |
+
# ------------------------------------------------------
|
193 |
+
# 9. Gradio Interface
|
194 |
+
# ------------------------------------------------------
|
195 |
+
language_choices = [
|
196 |
+
"Spanish", "Vietnamese", "Indonesian", "Turkish", "Portuguese", "Korean", "Chinese", "Japanese"
|
197 |
+
]
|
198 |
+
|
199 |
+
iface = gr.Interface(
|
200 |
+
fn=predict,
|
201 |
+
inputs=[
|
202 |
+
gr.Audio(type="numpy", label="Record/Upload English Audio (optional)"),
|
203 |
+
gr.Textbox(lines=4, placeholder="Or enter English text here", label="English Text Input (optional)"),
|
204 |
+
gr.Dropdown(choices=language_choices, value="Spanish", label="Target Language")
|
205 |
+
],
|
206 |
+
outputs=[
|
207 |
+
gr.Textbox(label="English Transcription"),
|
208 |
+
gr.Textbox(label="Translation (Target Language)"),
|
209 |
+
gr.Audio(label="Synthesized Speech")
|
210 |
+
],
|
211 |
+
title="Multimodal Language Learning Aid",
|
212 |
+
description=(
|
213 |
+
"This app performs the following steps:\n"
|
214 |
+
"1. Transcribes English speech using Wav2Vec2 (or accepts text input).\n"
|
215 |
+
"2. Translates the English text to the target language using Helsinki-NLP MarianMT models.\n"
|
216 |
+
"3. Synthesizes speech:\n"
|
217 |
+
" - For Spanish, Vietnamese, Indonesian, Turkish, Portuguese, and Korean: uses Facebook MMS TTS (VITS-based).\n"
|
218 |
+
" - For Chinese and Japanese: uses myshell-ai MeloTTS models.\n"
|
219 |
+
"\nSelect your target language from the dropdown."
|
220 |
+
),
|
221 |
+
allow_flagging="never"
|
222 |
+
)
|
223 |
+
|
224 |
+
if __name__ == "__main__":
|
225 |
+
iface.launch(server_name="0.0.0.0", server_port=7860)
|