Athspi committed on
Commit
22004d7
·
verified ·
1 Parent(s): f5e7901

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -214
app.py CHANGED
@@ -1,253 +1,155 @@
1
- import os
2
- import uuid
3
- import tempfile
4
- import logging
5
  from fastapi import FastAPI, Query, HTTPException
6
- from fastapi.responses import FileResponse, JSONResponse
7
-
8
- # --- gTTS Imports ---
9
- from gtts import gTTS, gTTSError
10
-
11
- # --- Hugging Face Imports ---
12
- # !! These imports might take time on first run or if downloading models !!
13
- try:
14
- from transformers import AutoProcessor, AutoModelForSpeechSeq2Seq
15
- import torch
16
- import scipy.io.wavfile
17
- # import soundfile as sf # Alternative audio saving
18
- print("Hugging Face Transformers loaded.")
19
- except ImportError:
20
- print("Error: Required Hugging Face libraries (transformers, torch, onnxruntime, scipy) not found.")
21
- print("Please install them: pip install -r requirements.txt")
22
- # Exit if core HF libraries are missing, as the HF endpoint won't work
23
- exit("Missing critical Hugging Face dependencies.")
24
-
25
 
26
- # --- Configuration & Setup ---
 
27
  TEMP_DIR = tempfile.gettempdir()
 
28
  os.makedirs(TEMP_DIR, exist_ok=True)
29
 
30
- # Configure logging
31
- logging.basicConfig(level=logging.INFO)
32
- logger = logging.getLogger(__name__)
33
-
34
- # --- Hugging Face Model Loading ---
35
- # Load the model and processor once when the application starts
36
- # This can take time and memory!
37
- HF_MODEL_ID = "willwade/mms-tts-multilingual-models-onnx"
38
- hf_processor = None
39
- hf_model = None
40
- hf_supported_langs = set() # Will store supported ISO 639-3 codes
41
-
42
- try:
43
- logger.info(f"Loading Hugging Face processor: {HF_MODEL_ID}...")
44
- hf_processor = AutoProcessor.from_pretrained(HF_MODEL_ID)
45
- logger.info(f"Loading Hugging Face model: {HF_MODEL_ID} (this may take time)...")
46
- # Specify provider=['CPUExecutionProvider'] if you don't have CUDA or want to force CPU
47
- hf_model = AutoModelForSpeechSeq2Seq.from_pretrained(HF_MODEL_ID, provider=['CPUExecutionProvider']) # Forces CPU via ONNX Runtime provider
48
-
49
- # Determine supported languages from model config (assuming standard MMS config)
50
- if hasattr(hf_model, 'config') and hasattr(hf_model.config, 'id2lang'):
51
- hf_supported_langs = set(hf_model.config.id2lang.values())
52
- logger.info(f"HF Model Supported Languages (ISO 639-3): {sorted(list(hf_supported_langs))}")
53
- else:
54
- logger.warning("Could not automatically determine supported languages from HF model config.")
55
- # Add known languages manually if needed, or leave empty to skip validation
56
- # hf_supported_langs = {'eng', 'spa', 'fra', ...}
57
-
58
- logger.info("Hugging Face model and processor loaded successfully.")
59
-
60
- except Exception as e:
61
- logger.error(f"FATAL: Failed to load Hugging Face model '{HF_MODEL_ID}': {e}", exc_info=True)
62
- # Depending on deployment, you might want the app to fail startup
63
- # Or allow it to run with only gTTS available
64
- logger.warning("Proceeding without Hugging Face TTS functionality.")
65
- hf_model = None # Ensure model is None if loading failed
66
-
67
-
68
  # --- FastAPI App Initialization ---
69
  app = FastAPI(
70
- title="Multi TTS API",
71
- description="API for Text-to-Speech using both gTTS and a Hugging Face MMS model.",
72
- version="2.0.0",
73
  )
74
 
75
-
76
  # --- API Endpoints ---
77
 
78
  @app.get("/", tags=["General"])
79
  def read_root():
80
  """
81
- Root endpoint providing a welcome message and available TTS engines.
82
  """
83
- engines = ["gTTS (/tts/gtts)"]
84
- if hf_model is not None:
85
- engines.append(f"HuggingFace MMS ({HF_MODEL_ID}) (/tts/hf)")
86
-
87
- return {
88
- "message": "Welcome to the Multi TTS API.",
89
- "available_engines": engines
90
- }
91
 
92
- # --- gTTS Endpoint ---
93
- @app.get("/tts/gtts", tags=["TTS - gTTS"])
94
- def text_to_speech_gtts(
95
  text: str = Query(
96
- ...,
97
  min_length=1,
98
- max_length=500,
99
  title="Text to Convert",
100
- description="The text to convert using Google Text-to-Speech."
101
  ),
102
  lang: str = Query(
103
- "en",
104
  min_length=2,
105
- max_length=10,
106
- title="Language Code (BCP 47)",
107
- description="The BCP 47 language code for gTTS (e.g., 'en', 'es', 'fr', 'zh-CN')."
108
  )
109
  ):
110
  """
111
- Generates speech using Google Text-to-Speech (gTTS).
112
- Returns an MP3 audio file.
 
 
113
  """
114
- logger.info(f"gTTS request received: lang='{lang}', text='{text[:50]}...'")
115
  try:
116
- # Generate a unique MP3 filename
117
- filename_mp3 = os.path.join(TEMP_DIR, f"gtts_{uuid.uuid4().hex}.mp3")
118
 
 
 
119
  tts_object = gTTS(text=text, lang=lang, slow=False)
120
- tts_object.save(filename_mp3)
121
- logger.info(f"gTTS generated audio file: {filename_mp3}")
122
 
 
 
 
 
 
123
  return FileResponse(
124
- path=filename_mp3,
125
  media_type="audio/mpeg",
126
- filename=f"gtts_speech_{lang}.mp3"
 
127
  )
128
 
129
  except gTTSError as e:
130
- logger.error(f"gTTS Error: {e} (lang={lang})", exc_info=False) # Don't need full stack trace for common gTTS errors
131
- raise HTTPException(status_code=400, detail=f"gTTS Error: {str(e)}. Ensure language '{lang}' is supported by gTTS.")
132
- except Exception as e:
133
- logger.error(f"Unexpected error in gTTS endpoint: {e}", exc_info=True)
134
- raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
135
-
136
- # --- Hugging Face MMS Endpoint ---
137
- @app.get("/tts/hf", tags=["TTS - HuggingFace MMS"])
138
- def text_to_speech_hf(
139
- text: str = Query(
140
- ...,
141
- min_length=1,
142
- max_length=500, # MMS might have different limits, adjust if known
143
- title="Text to Convert",
144
- description="The text to convert using the Hugging Face MMS model."
145
- ),
146
- lang_code: str = Query(
147
- "eng", # Default to English ISO code
148
- min_length=3,
149
- max_length=3,
150
- title="Language Code (ISO 639-3)",
151
- description=f"The ISO 639-3 language code for the MMS model (e.g., 'eng', 'spa', 'fra'). Supported: {sorted(list(hf_supported_langs)) if hf_supported_langs else 'Unknown - check logs/model card'}"
152
- )
153
- ):
154
- """
155
- Generates speech using the Hugging Face MMS model (`willwade/mms-tts-multilingual-models-onnx`).
156
- Returns a WAV audio file.
157
- """
158
- logger.info(f"HF MMS request received: lang_code='{lang_code}', text='{text[:50]}...'")
159
-
160
- if hf_model is None or hf_processor is None:
161
- logger.warning("HF endpoint called, but model/processor not loaded.")
162
- raise HTTPException(status_code=503, detail="Hugging Face TTS service is currently unavailable.")
163
-
164
- # --- Language Validation ---
165
- if hf_supported_langs and lang_code not in hf_supported_langs:
166
- logger.warning(f"Unsupported language code '{lang_code}' requested for HF model.")
167
- raise HTTPException(
168
- status_code=400,
169
- detail=f"Unsupported language code: '{lang_code}'. Supported codes (ISO 639-3): {sorted(list(hf_supported_langs))}"
170
- )
171
 
172
- try:
173
- # --- Preprocessing ---
174
- logger.debug("Preprocessing text with HF processor...")
175
- # MMS models often don't need language specified in the processor if handled by speaker_id in generate
176
- inputs = hf_processor(text, return_tensors="pt")
177
-
178
- # --- Speaker ID / Language Selection for Generation ---
179
- # The willwade model uses language codes directly mapped in its config
180
- target_lang_id = None
181
- if hasattr(hf_model, 'config') and hasattr(hf_model.config, 'lang_code_to_id'):
182
- target_lang_id = hf_model.config.lang_code_to_id.get(lang_code)
183
-
184
- if target_lang_id is None:
185
- logger.error(f"Could not find target language ID for code '{lang_code}' in model config.")
186
- # This check might be redundant if the initial validation passed, but good safeguard
187
- raise HTTPException(status_code=500, detail=f"Internal configuration error: Cannot map language code '{lang_code}' to model ID.")
188
-
189
- logger.debug(f"Generating speech with HF model for lang_code '{lang_code}' (ID: {target_lang_id})...")
190
- # --- Generation (using torch.no_grad for inference efficiency) ---
191
- with torch.no_grad():
192
- # Use speaker_id to specify the target language for MMS models
193
- # output_dict = hf_model.generate(**inputs, speaker_id=target_lang_id, return_dict=True) # Use this if output_attentions etc. needed
194
- outputs = hf_model.generate(**inputs, speaker_id=target_lang_id)
195
-
196
- # Extract waveform - adjust key/index based on actual model output structure if needed
197
- # Typically the primary output tensor is the waveform
198
- waveform = outputs[0].cpu().numpy().squeeze() # Get waveform, move to CPU, convert to numpy, remove batch dim if present
199
- logger.debug(f"Generated waveform shape: {waveform.shape}")
200
-
201
- if waveform.ndim != 1 or waveform.size == 0:
202
- logger.error(f"Unexpected waveform shape or size: {waveform.shape}")
203
- raise ValueError("Generated audio waveform is invalid.")
204
-
205
- # Get sampling rate from model config
206
- sampling_rate = hf_model.config.sampling_rate
207
- logger.debug(f"Using sampling rate: {sampling_rate}")
208
-
209
- # --- Save as WAV ---
210
- filename_wav = os.path.join(TEMP_DIR, f"hf_mms_{uuid.uuid4().hex}.wav")
211
- logger.info(f"Saving generated audio to WAV file: {filename_wav}")
212
-
213
- # Using scipy to write WAV
214
- # Ensure waveform is in the correct format (float32 or int16 typically)
215
- # MMS models usually output float32 between -1.0 and 1.0
216
- if waveform.dtype != 'float32':
217
- logger.warning(f"Waveform dtype is {waveform.dtype}, converting to float32 for saving.")
218
- waveform = waveform.astype('float32')
219
- # Scale if necessary, but MMS usually outputs in [-1, 1] range suitable for float32 wav
220
- scipy.io.wavfile.write(filename_wav, sampling_rate, waveform)
221
-
222
- # # Alternative using soundfile (often more robust)
223
- # sf.write(filename_wav, waveform, sampling_rate, subtype='FLOAT') # Use 'PCM_16' if int16 desired
224
-
225
- return FileResponse(
226
- path=filename_wav,
227
- media_type="audio/wav",
228
- filename=f"hf_mms_speech_{lang_code}.wav"
229
- )
230
-
231
- except ValueError as e: # Catch specific errors like invalid waveform
232
- logger.error(f"Value error during HF TTS processing: {e}", exc_info=True)
233
- raise HTTPException(status_code=400, detail=f"Input or Processing Error: {str(e)}")
234
  except Exception as e:
235
- logger.error(f"Unexpected error in HF MMS endpoint: {e}", exc_info=True)
236
- raise HTTPException(status_code=500, detail=f"Internal Server Error: {str(e)}")
237
-
238
-
239
- # --- How to Run ---
240
- # 1. Save this code as `app.py`.
241
- # 2. Create `requirements.txt` (as shown above).
242
- # 3. Install dependencies: `pip install -r requirements.txt`
243
- # 4. Run the FastAPI server: `uvicorn app:app --reload`
244
- # (Use `--host 0.0.0.0` if running in Docker or need external access)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  #
246
- # --- Example Usage ---
247
- # - gTTS (English): http://127.0.0.1:8000/tts/gtts?text=Hello+from+gTTS&lang=en
248
- # - gTTS (Spanish): http://127.0.0.1:8000/tts/gtts?text=Hola+desde+gTTS&lang=es
249
  #
250
- # - HF MMS (English): http://127.0.0.1:8000/tts/hf?text=Hello+from+the+MMS+model&lang_code=eng
251
- # - HF MMS (Spanish): http://127.0.0.1:8000/tts/hf?text=Hola+desde+el+modelo+MMS&lang_code=spa
252
- # - HF MMS (French): http://127.0.0.1:8000/tts/hf?text=Bonjour+du+modèle+MMS&lang_code=fra
253
- # (Check supported 'lang_code' values from server logs or model card)
 
1
+ # app.py
 
 
 
2
  from fastapi import FastAPI, Query, HTTPException
3
+ from fastapi.responses import FileResponse
4
+ from gtts import gTTS, gTTSError # Import gTTSError for specific error handling
5
+ import uuid
6
+ import os # Import os module for path operations
7
+ import tempfile # Import tempfile for better temporary file handling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ # --- Configuration ---
10
+ # Use tempfile to get a cross-platform temporary directory
11
  TEMP_DIR = tempfile.gettempdir()
12
+ # Ensure the temporary directory exists
13
  os.makedirs(TEMP_DIR, exist_ok=True)
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  # --- FastAPI App Initialization ---
16
  app = FastAPI(
17
+ title="gTTS API",
18
+ description="A simple API to convert text to speech using Google Text-to-Speech (gTTS). Supports multiple languages including Tamil, Sinhala, and many others.",
19
+ version="1.2.0", # Increment version for documentation changes
20
  )
21
 
 
22
  # --- API Endpoints ---
23
 
24
  @app.get("/", tags=["General"])
25
  def read_root():
26
  """
27
+ Root endpoint providing a welcome message.
28
  """
29
+ return {"message": "Welcome to the gTTS API. Use the /tts endpoint to generate speech."}
 
 
 
 
 
 
 
30
 
31
+ @app.get("/tts", tags=["Text-to-Speech"])
32
+ def text_to_speech(
 
33
  text: str = Query(
34
+ ..., # Ellipsis makes the parameter required
35
  min_length=1,
36
+ max_length=500, # Adjust max length as needed
37
  title="Text to Convert",
38
+ description="The text you want to convert into speech (1-500 characters)."
39
  ),
40
  lang: str = Query(
41
+ "en", # Default language is English
42
  min_length=2,
43
+ max_length=10, # Allow for language codes like 'en-us', 'zh-CN' etc.
44
+ title="Language Code",
45
+ description="The BCP 47 language code for the speech synthesis (e.g., 'en', 'es', 'ta', 'si', 'ja', 'zh-CN'). See gTTS documentation for supported languages."
46
  )
47
  ):
48
  """
49
+ Converts the provided text into an MP3 audio file using the specified language.
50
+
51
+ - **text**: The text to synthesize (required).
52
+ - **lang**: The language code (e.g., 'en', 'es', 'fr', 'ta', 'si'). Defaults to 'en'. **Crucially, gTTS must support this language code.**
53
  """
 
54
  try:
55
+ # Generate a unique filename in the configured temporary directory
56
+ filename = os.path.join(TEMP_DIR, f"{uuid.uuid4().hex}.mp3")
57
 
58
+ # Create gTTS object with text and language
59
+ # Use slow=False for normal speed speech
60
  tts_object = gTTS(text=text, lang=lang, slow=False)
 
 
61
 
62
+ # Save the audio file
63
+ tts_object.save(filename)
64
+
65
+ # Return the audio file as a response
66
+ # The 'filename' parameter sets the download name for the browser
67
  return FileResponse(
68
+ path=filename,
69
  media_type="audio/mpeg",
70
+ filename=f"speech_{lang}.mp3" # Suggest a filename like speech_en.mp3 or speech_ta.mp3
71
+ # TODO: the generated temp file is never deleted; consider removing it after the response is sent (e.g. with a FastAPI background task)
72
  )
73
 
74
  except gTTSError as e:
75
+ # Handle specific gTTS errors (like invalid language code, network issues)
76
+ detail_message = f"gTTS Error: {str(e)}. Ensure the language code '{lang}' is supported and text is appropriate for the language."
77
+ # Check common error patterns
78
+ if "400 (Bad Request)" in str(e) or "Language not supported" in str(e):
79
+ raise HTTPException(status_code=400, detail=detail_message)
80
+ elif "500 (Internal Server Error)" in str(e) or "Failed to connect" in str(e):
81
+ # Treat these as potential temporary Google service issues
82
+ raise HTTPException(status_code=503, detail=f"Service Error: {str(e)}. Could be a temporary issue with the TTS service.")
83
+ else: # Other gTTS errors
84
+ raise HTTPException(status_code=503, detail=detail_message) # 503 Service Unavailable likely
85
+
86
+ except ValueError as e:
87
+ # Potentially handle other value errors if gTTS raises them for certain inputs
88
+ raise HTTPException(status_code=400, detail=f"Input Error: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  except Exception as e:
91
+ # Catch any other unexpected errors
92
+ # Log the error for debugging
93
+ # import logging
94
+ # logging.exception(f"An unexpected error occurred during TTS generation for lang='{lang}'")
95
+ raise HTTPException(status_code=500, detail=f"Internal Server Error: An unexpected error occurred.")
96
+
97
+ # --- How to Run (Instructions) ---
98
+ # 1. Save this code as `app.py`.
99
+ # 2. Install necessary libraries:
100
+ # pip install fastapi "uvicorn[standard]" gTTS
101
+ # 3. Run the FastAPI server using Uvicorn:
102
+ # uvicorn app:app --reload
103
+ #
104
+ # --- How to Use - Examples ---
105
+ # Open your browser or use a tool like curl/Postman.
106
+ # Access the TTS endpoint with the 'text' and 'lang' query parameters.
107
+ # NOTE: Text containing non-ASCII characters needs to be URL-encoded. Most browsers do this automatically.
108
+ #
109
+ # - English (en - Default):
110
+ # Text: "Hello, world!"
111
+ # URL: http://127.0.0.1:8000/tts?text=Hello%2C%20world%21
112
+ #
113
+ # - Spanish (es):
114
+ # Text: "Hola Mundo"
115
+ # URL: http://127.0.0.1:8000/tts?text=Hola%20Mundo&lang=es
116
+ #
117
+ # - French (fr):
118
+ # Text: "Bonjour le monde"
119
+ # URL: http://127.0.0.1:8000/tts?text=Bonjour%20le%20monde&lang=fr
120
+ #
121
+ # - German (de):
122
+ # Text: "Hallo Welt"
123
+ # URL: http://127.0.0.1:8000/tts?text=Hallo%20Welt&lang=de
124
+ #
125
+ # - Tamil (ta):
126
+ # Text: "வணக்கம் உலகம்"
127
+ # URL: http://127.0.0.1:8000/tts?text=%E0%AE%B5%E0%AE%A3%E0%AE%95%E0%AF%8D%E0%AE%95%E0%AE%AE%E0%AF%8D%20%E0%AE%89%E0%AE%B2%E0%AE%95%E0%AE%AE%E0%AF%8D&lang=ta
128
+ #
129
+ # - Sinhala (si):
130
+ # Text: "හෙලෝ ලෝකය"
131
+ # URL: http://127.0.0.1:8000/tts?text=%E0%B7%84%E0%B7%99%E0%B6%BD%E0%B7%9D%20%E0%B6%BD%E0%B7%9D%E0%B6%9A%E0%B6%BA&lang=si
132
+ #
133
+ # - Japanese (ja):
134
+ # Text: "こんにちは世界"
135
+ # URL: http://127.0.0.1:8000/tts?text=%E3%81%93%E3%82%93%E3%81%AB%E3%81%A1%E3%81%AF%E4%B8%96%E7%95%8C&lang=ja
136
+ #
137
+ # - Chinese (Mandarin, Simplified) (zh-CN):
138
+ # Text: "你好世界"
139
+ # URL: http://127.0.0.1:8000/tts?text=%E4%BD%A0%E5%A5%BD%E4%B8%96%E7%95%8C&lang=zh-CN
140
+ #
141
+ # - Russian (ru):
142
+ # Text: "Привет мир"
143
+ # URL: http://127.0.0.1:8000/tts?text=%D0%9F%D1%80%D0%B8%D0%B2%D0%B5%D1%82%20%D0%BC%D0%B8%D1%80&lang=ru
144
+ #
145
+ # - Hindi (hi):
146
+ # Text: "नमस्ते दुनिया"
147
+ # URL: http://127.0.0.1:8000/tts?text=%E0%A4%A8%E0%A4%AE%E0%A4%B8%E0%A5%8D%E0%A4%A4%E0%A5%87%20%E0%A4%A6%E0%A5%81%E0%A4%A8%E0%A4%BF%E0%A4%AF%E0%A4%BE&lang=hi
148
  #
149
+ # - Arabic (ar):
150
+ # Text: "مرحبا بالعالم"
151
+ # URL: http://127.0.0.1:8000/tts?text=%D9%85%D8%B1%D8%AD%D8%A8%D8%A7%20%D8%A8%D8%A7%D9%84%D8%B9%D8%A7%D9%84%D9%85&lang=ar
152
  #
153
+ # Find more supported language codes in the gTTS documentation or common lists of BCP 47 codes.
154
+ # The API will return an MP3 file download or playback depending on your browser/client.
155
+ # If you provide an unsupported language code, you should get a 400 Bad Request error.