Yilin0601 committed on
Commit
824ebbf
·
verified ·
1 Parent(s): 4f50eab

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +225 -13
app.py CHANGED
@@ -1,13 +1,225 @@
1
- torch
2
- transformers>=4.33.0
3
- gradio
4
- librosa
5
- numpy
6
- scipy
7
- accelerate
8
- sentencepiece
9
- soundfile
10
- datasets
11
- TTS
12
- git+https://github.com/myshell-ai/MeloTTS-Chinese.git
13
- git+https://github.com/myshell-ai/MeloTTS-Japanese.git
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import numpy as np
4
+ import librosa
5
+ import soundfile as sf
6
+ import tempfile
7
+ import os
8
+
9
+ from transformers import pipeline, VitsModel, AutoTokenizer
10
+ from datasets import load_dataset
11
+
12
+ # For MeloTTS (Chinese and Japanese)
13
# MeloTTS is an optional third-party package; fail fast at import time with an
# actionable message, because the Chinese/Japanese TTS path cannot work without it.
try:
    from melo.api import TTS as MeloTTS
except ImportError as exc:
    # The earlier hint ("pip install myshell-ai/MeloTTS-Chinese") is not a valid
    # pip requirement; point at the upstream Git repository instead and keep the
    # original ImportError attached as the cause for debugging.
    raise ImportError(
        "Please install the MeloTTS package "
        "(e.g., pip install git+https://github.com/myshell-ai/MeloTTS.git)"
    ) from exc
17
+
18
+ # ------------------------------------------------------
19
+ # 1. ASR Pipeline (English) using Wav2Vec2
20
+ # ------------------------------------------------------
21
# English speech-to-text pipeline; built once at import time and reused by
# every call to predict().
asr = pipeline(task="automatic-speech-recognition", model="facebook/wav2vec2-base-960h")
25
+
26
+ # ------------------------------------------------------
27
+ # 2. Translation Models (8 languages)
28
+ # ------------------------------------------------------
29
# Target-language name -> Hugging Face checkpoint used to translate from English.
# Keys must stay in sync with translation_tasks and tts_config.
# NOTE(review): "opus-mt-en-trk" looks like a multi-target (Turkic family)
# checkpoint rather than a Turkish-only one - confirm whether it needs a
# target-language token in the input text.
translation_models = dict(
    Spanish="Helsinki-NLP/opus-mt-en-es",
    Vietnamese="Helsinki-NLP/opus-mt-en-vi",
    Indonesian="Helsinki-NLP/opus-mt-en-id",
    Turkish="Helsinki-NLP/opus-mt-en-trk",
    Portuguese="Helsinki-NLP/opus-mt-tc-big-en-pt",
    Korean="Helsinki-NLP/opus-mt-tc-big-en-ko",
    Chinese="Helsinki-NLP/opus-mt-en-zh",
    Japanese="Helsinki-NLP/opus-mt-en-jap",
)
39
+
40
# Target-language name -> pipeline task string passed to transformers.pipeline().
# Every value follows the "translation_<src>_to_<tgt>" form; the Korean entry
# previously read "translation_en_to-ko" (hyphen instead of underscore), which
# does not match that form.
translation_tasks = {
    "Spanish": "translation_en_to_es",
    "Vietnamese": "translation_en_to_vi",
    "Indonesian": "translation_en_to_id",
    "Turkish": "translation_en_to_tr",
    "Portuguese": "translation_en_to_pt",
    "Korean": "translation_en_to_ko",
    "Chinese": "translation_en_to_zh",
    "Japanese": "translation_en_to_ja",
}
50
+
51
+ # ------------------------------------------------------
52
+ # 3. TTS Configuration
53
+ # - MMS TTS (VITS) for: Spanish, Vietnamese, Indonesian, Turkish, Portuguese, Korean
54
+ # - MeloTTS for: Chinese and Japanese
55
+ # ------------------------------------------------------
56
# Target-language name -> TTS backend settings.
#   type == "mms":  a VITS checkpoint identified by "model_id" (see load_mms_tts).
#   type == "melo": handled by MeloTTS; no checkpoint id is stored here.
tts_config = dict(
    Spanish=dict(model_id="facebook/mms-tts-spa", architecture="vits", type="mms"),
    Vietnamese=dict(model_id="facebook/mms-tts-vie", architecture="vits", type="mms"),
    Indonesian=dict(model_id="facebook/mms-tts-ind", architecture="vits", type="mms"),
    Turkish=dict(model_id="facebook/mms-tts-tur", architecture="vits", type="mms"),
    Portuguese=dict(model_id="facebook/mms-tts-por", architecture="vits", type="mms"),
    Korean=dict(model_id="facebook/mms-tts-kor", architecture="vits", type="mms"),
    Chinese=dict(type="melo"),
    Japanese=dict(type="melo"),
)
66
+
67
+ # ------------------------------------------------------
68
+ # 4. Global Caches for Translators and TTS Models
69
+ # ------------------------------------------------------
70
# Process-wide caches keyed by target-language name. Each entry is filled the
# first time that language is requested and then reused.
translator_cache: dict = {}  # language -> translation pipeline
mms_tts_cache: dict = {}  # language -> (VitsModel, tokenizer) pair
melo_tts_cache: dict = {}  # language -> MeloTTS model (Chinese/Japanese)
73
+
74
+ # ------------------------------------------------------
75
+ # 5. Translator Helper
76
+ # ------------------------------------------------------
77
def get_translator(lang):
    """Return the English-to-``lang`` translation pipeline, creating it on first use.

    The pipeline is stored in ``translator_cache`` so later calls for the same
    language reuse it.

    Args:
        lang: Target-language name; must be a key of both
            ``translation_models`` and ``translation_tasks``.

    Returns:
        The pipeline object built by ``pipeline(task, model=checkpoint)``.

    Raises:
        KeyError: If ``lang`` is missing from either lookup table.
    """
    try:
        return translator_cache[lang]
    except KeyError:
        pass  # Not built yet; fall through and construct it.

    checkpoint = translation_models[lang]
    task = translation_tasks[lang]
    translator_cache[lang] = pipeline(task, model=checkpoint)
    return translator_cache[lang]
85
+
86
+ # ------------------------------------------------------
87
+ # 6. MMS TTS (VITS) Helper for languages using MMS TTS
88
+ # ------------------------------------------------------
89
def load_mms_tts(lang):
    """Return the cached ``(model, tokenizer)`` pair for an MMS TTS language.

    On first request the checkpoint named by ``tts_config[lang]["model_id"]``
    is loaded with ``VitsModel`` and ``AutoTokenizer`` and stored in
    ``mms_tts_cache``.

    Args:
        lang: Target-language name; must be a key of ``tts_config``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        KeyError: If ``lang`` is not in ``tts_config``.
        RuntimeError: If loading the model or tokenizer fails.
    """
    if lang not in mms_tts_cache:
        settings = tts_config[lang]
        try:
            checkpoint = settings["model_id"]
            mms_tts_cache[lang] = (
                VitsModel.from_pretrained(checkpoint),
                AutoTokenizer.from_pretrained(checkpoint),
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load MMS TTS model for {lang} ({settings['model_id']}): {e}")
    return mms_tts_cache[lang]
100
+
101
def run_mms_tts(text, lang):
    """Synthesize ``text`` with the MMS (VITS) model configured for ``lang``.

    Args:
        text: Text to speak, already in the target language.
        lang: Target-language name understood by ``load_mms_tts``.

    Returns:
        A ``(sample_rate, waveform)`` tuple where ``waveform`` is a NumPy array
        taken from the model output.

    Raises:
        RuntimeError: If the model output has no ``waveform`` attribute, or if
            ``load_mms_tts`` fails to load the model.
    """
    model, tokenizer = load_mms_tts(lang)
    inputs = tokenizer(text, return_tensors="pt")
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(**inputs)
    if not hasattr(output, "waveform"):
        raise RuntimeError(f"MMS TTS model output for {lang} does not contain 'waveform'.")
    waveform = output.waveform.squeeze().cpu().numpy()
    # Take the rate from the loaded checkpoint rather than hard-coding it, so
    # playback speed stays correct if a checkpoint uses a different rate.
    # 16000 remains the fallback when the config does not expose one.
    sample_rate = getattr(getattr(model, "config", None), "sampling_rate", None) or 16000
    return sample_rate, waveform
111
+
112
+ # ------------------------------------------------------
113
+ # 7. MeloTTS Helper for Chinese and Japanese
114
+ # ------------------------------------------------------
115
def run_melo_tts(text, lang):
    """Synthesize ``text`` with a MeloTTS model and return the audio.

    ``"Chinese"`` selects the MeloTTS language code ``'ZH'``; any other value
    of ``lang`` selects ``'JP'``. The model is created on first use and kept in
    ``melo_tts_cache``. Audio is written to a temporary WAV file, read back
    with ``soundfile`` and the file is removed afterwards.

    Args:
        text: Text to speak, already in the target language.
        lang: Target-language name (``"Chinese"`` or ``"Japanese"``).

    Returns:
        A ``(sample_rate, data)`` tuple as read from the generated WAV file.

    Raises:
        RuntimeError: If the MeloTTS model cannot be constructed.
    """
    if lang == "Chinese":
        melo_language = 'ZH'
    else:
        melo_language = 'JP'
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if lang in melo_tts_cache:
        engine = melo_tts_cache[lang]
    else:
        try:
            engine = MeloTTS(language=melo_language, device=target_device)
            melo_tts_cache[lang] = engine
        except Exception as e:
            raise RuntimeError(f"Failed to load MeloTTS model for {lang}: {e}")

    speaker_table = engine.hps.data.spk2id
    # The speaker entry is assumed to be named after the language code.
    speaker_name = melo_language
    playback_speed = 1.0

    # Reserve a path only; the handle is closed before MeloTTS writes to it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as scratch:
        wav_path = scratch.name
    try:
        engine.tts_to_file(text, speaker_table[speaker_name], wav_path, speed=playback_speed)
        samples, rate = sf.read(wav_path)
    finally:
        # Always clean up the temporary file, even when synthesis fails.
        if os.path.exists(wav_path):
            os.remove(wav_path)
    return rate, samples
143
+
144
+ # ------------------------------------------------------
145
+ # 8. Main Prediction Function
146
+ # ------------------------------------------------------
147
def _prepare_asr_audio(sample_rate, audio_data):
    """Convert a raw recording into the mono, 16 kHz float array given to ``asr``."""
    audio_data = np.asarray(audio_data)
    if np.issubdtype(audio_data.dtype, np.signedinteger):
        # Integer PCM (e.g. int16) is rescaled to roughly [-1.0, 1.0) instead of
        # being cast to float at its raw integer magnitude.
        full_scale = float(np.iinfo(audio_data.dtype).max) + 1.0
        audio_data = audio_data.astype(np.float32) / full_scale
    elif audio_data.dtype not in (np.float32, np.float64):
        audio_data = audio_data.astype(np.float32)
    if audio_data.ndim > 1:
        # Down-mix to mono. Channels are assumed to be on axis 1, as before;
        # this also flattens a single-channel (n, 1) array to 1-D.
        audio_data = np.mean(audio_data, axis=1)
    if sample_rate != 16000:
        audio_data = librosa.resample(audio_data, orig_sr=sample_rate, target_sr=16000)
    return audio_data


def predict(audio, text, target_language):
    """Run the transcribe -> translate -> synthesize chain for one request.

    1. Obtain English text: typed ``text`` wins; otherwise ``audio`` is
       transcribed with the ``asr`` pipeline.
    2. Translate the English text into ``target_language``.
    3. Synthesize the translation with MMS TTS or MeloTTS, as selected by
       ``tts_config[target_language]["type"]``.

    Args:
        audio: ``(sample_rate, samples)`` tuple, or ``None`` when no audio
            was supplied.
        text: English text typed by the user; may be empty or ``None``.
        target_language: Target-language name (a key of the lookup tables).

    Returns:
        A 3-tuple ``(english_text, translated_text, audio_out)`` matching the
        three interface outputs. ``audio_out`` is ``(sample_rate, waveform)``
        on success and ``None`` whenever no audio could be produced; in that
        case the reason is reported in one of the two text fields.
    """
    no_input = ("No input provided.", "", None)

    # Step 1: Get English text. ``text`` may be None, so guard before strip().
    typed_text = (text or "").strip()
    if typed_text:
        english_text = typed_text
    elif audio is not None:
        sample_rate, audio_data = audio
        audio_data = _prepare_asr_audio(sample_rate, audio_data)
        if audio_data.size == 0:
            # An empty recording carries nothing to transcribe.
            return no_input
        asr_result = asr({"array": audio_data, "sampling_rate": 16000})
        english_text = asr_result["text"]
    else:
        return no_input

    # Step 2: Translate. Loading the translator is inside the try so that a
    # model-loading failure is reported like any other translation failure.
    # The broad except is deliberate: this is the UI boundary.
    try:
        translator = get_translator(target_language)
        translated_text = translator(english_text)[0]["translation_text"]
    except Exception as e:
        return english_text, f"Translation error: {e}", None

    # Step 3: TTS.
    try:
        tts_type = tts_config[target_language]["type"]
        if tts_type == "mms":
            sr, waveform = run_mms_tts(translated_text, target_language)
        elif tts_type == "melo":
            sr, waveform = run_melo_tts(translated_text, target_language)
        else:
            raise RuntimeError("Unknown TTS type for target language.")
    except Exception as e:
        # The third output is an audio component, so an error string must not
        # be returned there; keep the translation visible and append the error.
        return english_text, f"{translated_text}\n[TTS error: {e}]", None

    return english_text, translated_text, (sr, waveform)
191
+
192
+ # ------------------------------------------------------
193
+ # 9. Gradio Interface
194
+ # ------------------------------------------------------
195
# Languages offered in the dropdown, in display order.
language_choices = [
    "Spanish",
    "Vietnamese",
    "Indonesian",
    "Turkish",
    "Portuguese",
    "Korean",
    "Chinese",
    "Japanese",
]

# Input widgets, in the order predict() receives them: audio, text, language.
audio_input = gr.Audio(type="numpy", label="Record/Upload English Audio (optional)")
text_input = gr.Textbox(
    lines=4,
    placeholder="Or enter English text here",
    label="English Text Input (optional)",
)
language_input = gr.Dropdown(choices=language_choices, value="Spanish", label="Target Language")

# Output widgets, in the order predict() returns them.
transcription_output = gr.Textbox(label="English Transcription")
translation_output = gr.Textbox(label="Translation (Target Language)")
speech_output = gr.Audio(label="Synthesized Speech")

# Text shown under the title of the web page.
app_description = (
    "This app performs the following steps:\n"
    "1. Transcribes English speech using Wav2Vec2 (or accepts text input).\n"
    "2. Translates the English text to the target language using Helsinki-NLP MarianMT models.\n"
    "3. Synthesizes speech:\n"
    " - For Spanish, Vietnamese, Indonesian, Turkish, Portuguese, and Korean: uses Facebook MMS TTS (VITS-based).\n"
    " - For Chinese and Japanese: uses myshell-ai MeloTTS models.\n"
    "\nSelect your target language from the dropdown."
)

iface = gr.Interface(
    fn=predict,
    inputs=[audio_input, text_input, language_input],
    outputs=[transcription_output, translation_output, speech_output],
    title="Multimodal Language Learning Aid",
    description=app_description,
    allow_flagging="never",
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)