Yilin0601 committed · verified
Commit fa64981 · 1 Parent(s): c953920

Update app.py

Files changed (1):
  app.py  +42 -38
app.py CHANGED
@@ -9,11 +9,11 @@ import os
 from transformers import pipeline, VitsModel, AutoTokenizer
 from datasets import load_dataset
 
-# For MeloTTS (Chinese and Japanese)
+# For Coqui TTS (XTTS-v2)
 try:
-    from melo.api import TTS as MeloTTS
+    from TTS.api import TTS as CoquiTTS
 except ImportError:
-    raise ImportError("Please install the MeloTTS package (e.g., pip install myshell-ai/MeloTTS-Chinese)")
+    raise ImportError("Please install Coqui TTS via pip install TTS.")
 
 # ------------------------------------------------------
 # 1. ASR Pipeline (English) using Wav2Vec2
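The ASR section itself is unchanged and therefore outside this hunk's context. For orientation only, a minimal sketch of what a Wav2Vec2-based English ASR pipeline in transformers typically looks like; the model id below is an assumption, not taken from app.py:

from transformers import pipeline

# Assumed English Wav2Vec2 checkpoint; the actual id used by the app is not visible in this diff.
asr_pipeline = pipeline("automatic-speech-recognition", model="facebook/wav2vec2-base-960h")

# Accepts a path to an audio file or a 16 kHz float numpy array.
result = asr_pipeline("speech_sample.wav")
print(result["text"])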
@@ -51,7 +51,7 @@ translation_tasks = {
 # ------------------------------------------------------
 # 3. TTS Configuration
 #  - MMS TTS (VITS) for: Spanish, Vietnamese, Indonesian, Turkish, Portuguese, Korean
-#  - MeloTTS for: Chinese and Japanese
+#  - Coqui XTTS-v2 for: Chinese and Japanese
 # ------------------------------------------------------
 tts_config = {
     "Spanish": {"model_id": "facebook/mms-tts-spa", "architecture": "vits", "type": "mms"},
@@ -60,8 +60,14 @@ tts_config = {
     "Turkish": {"model_id": "facebook/mms-tts-tur", "architecture": "vits", "type": "mms"},
     "Portuguese": {"model_id": "facebook/mms-tts-por", "architecture": "vits", "type": "mms"},
     "Korean": {"model_id": "facebook/mms-tts-kor", "architecture": "vits", "type": "mms"},
-    "Chinese": {"type": "melo"},
-    "Japanese": {"type": "melo"}
+    "Chinese": {"type": "coqui"},
+    "Japanese": {"type": "coqui"}
+}
+
+# For Coqui, we map our languages to language codes expected by the model.
+coqui_lang_map = {
+    "Chinese": "zh",
+    "Japanese": "ja"
 }
 
 # ------------------------------------------------------
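The body of run_mms_tts is not part of the changed hunks. As a rough sketch of how an MMS "vits" entry from tts_config is normally driven with the transformers classes imported above (everything here is an assumption about the unchanged code, not part of this commit):

import torch
from transformers import VitsModel, AutoTokenizer

model_id = tts_config["Spanish"]["model_id"]      # "facebook/mms-tts-spa"
model = VitsModel.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)

inputs = tokenizer("Hola, ¿cómo estás?", return_tensors="pt")
with torch.no_grad():
    waveform = model(**inputs).waveform[0].cpu().numpy()
sample_rate = model.config.sampling_rate          # 16 kHz for the MMS TTS checkpoints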
@@ -69,7 +75,7 @@ tts_config = {
 # ------------------------------------------------------
 translator_cache = {}
 mms_tts_cache = {}  # For MMS (VITS-based) TTS models
-melo_tts_cache = {}  # For MeloTTS models (Chinese/Japanese)
+coqui_tts_cache = None  # Single instance for Coqui XTTS-v2
 
 # ------------------------------------------------------
 # 5. Translator Helper
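The translator helper named in the section header above is also outside the diff context. A hypothetical sketch of how such a cached Helsinki-NLP MarianMT helper usually looks; the real translation_tasks layout, function signature, and model ids are assumptions:

from transformers import pipeline

def get_translator(lang):
    # Reuse a cached pipeline per target language.
    if lang in translator_cache:
        return translator_cache[lang]
    model_id = {"Spanish": "Helsinki-NLP/opus-mt-en-es",
                "Chinese": "Helsinki-NLP/opus-mt-en-zh"}[lang]  # illustrative subset only
    translator = pipeline("translation", model=model_id)
    translator_cache[lang] = translator
    return translator

# Usage: get_translator("Spanish")("How are you today?")[0]["translation_text"]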
@@ -110,31 +116,31 @@ def run_mms_tts(text, lang):
     return sample_rate, waveform
 
 # ------------------------------------------------------
-# 7. MeloTTS Helper for Chinese and Japanese
+# 7. Coqui TTS Helper for Chinese and Japanese
 # ------------------------------------------------------
-def run_melo_tts(text, lang):
-    """
-    Uses the myshell-ai MeloTTS model.
-    For Chinese, use language parameter 'ZH'; for Japanese, use 'JP'.
-    """
-    device = 'cuda' if torch.cuda.is_available() else 'cpu'
-    lang_param = 'ZH' if lang == "Chinese" else 'JP'
-    if lang not in melo_tts_cache:
-        try:
-            model = MeloTTS(language=lang_param, device=device)
-            melo_tts_cache[lang] = model
-        except Exception as e:
-            raise RuntimeError(f"Failed to load MeloTTS model for {lang}: {e}")
-    else:
-        model = melo_tts_cache[lang]
-    speaker_ids = model.hps.data.spk2id
-    # Assume the speaker key is the same as lang_param
-    speaker_key = lang_param
-    speed = 1.0
+def load_coqui_tts():
+    global coqui_tts_cache
+    if coqui_tts_cache is not None:
+        return coqui_tts_cache
+    try:
+        # Set gpu=True if a GPU is available.
+        coqui_tts_cache = CoquiTTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)
+    except Exception as e:
+        raise RuntimeError(f"Failed to load Coqui XTTS-v2 TTS: {e}")
+    return coqui_tts_cache
+
+def run_coqui_tts(text, lang):
+    coqui_tts = load_coqui_tts()
+    lang_code = coqui_lang_map[lang]  # "zh" for Chinese or "ja" for Japanese
+    # Write the output to a temporary file and then read it back.
     with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
         tmp_name = tmp.name
     try:
-        model.tts_to_file(text, speaker_ids[speaker_key], tmp_name, speed=speed)
+        coqui_tts.tts_to_file(
+            text=text,
+            file_path=tmp_name,
+            language=lang_code  # using default voice; for cloning, add speaker_wav parameter
+        )
         data, sr = sf.read(tmp_name)
     finally:
         if os.path.exists(tmp_name):
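For reference, the Coqui call the new helper wraps can be exercised on its own roughly as below. XTTS-v2 is a voice-cloning model and generally expects a speaker reference (a speaker_wav clip or a built-in speaker name); the reference path here is a placeholder, not something taken from this commit. After synthesis, the helper above reads the temporary WAV back with soundfile, presumably returning (sr, data) to mirror the MMS helper:

from TTS.api import TTS

tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2", gpu=False)
tts.tts_to_file(
    text="你好，今天过得怎么样？",
    file_path="xtts_demo.wav",
    language="zh-cn",                   # "zh-cn" and "ja" appear in XTTS-v2's supported language list
    speaker_wav="reference_voice.wav",  # placeholder reference clip for voice cloning
)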
@@ -147,8 +153,8 @@ def run_melo_tts(text, lang):
 def predict(audio, text, target_language):
     """
     1. Obtain English text (via ASR if audio provided, else text).
-    2. Translate the English text to target_language.
-    3. Generate TTS audio using either MMS TTS (VITS) or MeloTTS.
+    2. Translate English text to target_language.
+    3. Generate TTS audio using either MMS TTS (VITS) or Coqui XTTS-v2.
     """
     # Step 1: Get English text.
     if text.strip():
@@ -180,8 +186,8 @@ def predict(audio, text, target_language):
         tts_type = tts_config[target_language]["type"]
         if tts_type == "mms":
             sr, waveform = run_mms_tts(translated_text, target_language)
-        elif tts_type == "melo":
-            sr, waveform = run_melo_tts(translated_text, target_language)
+        elif tts_type == "coqui":
+            sr, waveform = run_coqui_tts(translated_text, target_language)
         else:
             raise RuntimeError("Unknown TTS type for target language.")
     except Exception as e:
@@ -212,14 +218,12 @@ iface = gr.Interface(
     description=(
         "This app performs the following steps:\n"
         "1. Transcribes English speech using Wav2Vec2 (or accepts text input).\n"
-        "2. Translates the English text to the target language using Helsinki-NLP MarianMT models.\n"
-        "3. Synthesizes speech:\n"
-        "   - For Spanish, Vietnamese, Indonesian, Turkish, Portuguese, and Korean: uses Facebook MMS TTS (VITS-based).\n"
-        "   - For Chinese and Japanese: uses myshell-ai MeloTTS models.\n"
-        "\nSelect your target language from the dropdown."
+        "2. Translates the English text to the target language using Helsinki-NLP models.\n"
+        "3. Synthesizes speech: Facebook MMS TTS (VITS) for Spanish, Vietnamese, Indonesian, Turkish, Portuguese, and Korean;\n"
+        "   Coqui XTTS-v2 for Chinese and Japanese. Select your target language from the dropdown."
     ),
     allow_flagging="never"
 )
 
 if __name__ == "__main__":
-    iface.launch(server_name="0.0.0.0", server_port=7860)
+    iface.launch(server_name="0.0.0.0", server_port=7860)
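The remaining gr.Interface arguments (fn, inputs, outputs, title) are unchanged and not shown in the diff. A rough reconstruction of how such an app is usually wired, with the component labels, the output list, and the title being assumptions rather than the actual code:

import gradio as gr

iface = gr.Interface(
    fn=predict,
    inputs=[
        gr.Audio(type="numpy", label="English speech (optional)"),
        gr.Textbox(label="...or type English text"),
        gr.Dropdown(choices=list(tts_config.keys()), label="Target language"),
    ],
    # Gradio's numpy audio convention: the function returns (sample_rate, waveform).
    outputs=[gr.Textbox(label="Translated text"), gr.Audio(label="Synthesized speech")],
    title="English ASR -> Translation -> TTS",
    allow_flagging="never",
)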
 