Jerich committed on
Commit
ec27d4d
·
verified ·
1 Parent(s): c10c930

Add STT functionality with openai/whisper-tiny

Browse files
Files changed (1) hide show
  1. app.py +193 -15
app.py CHANGED
@@ -1,12 +1,104 @@
 
 
 
 
1
  import logging
2
- from fastapi import FastAPI, HTTPException, Form
 
 
 
 
 
3
  from fastapi.responses import JSONResponse
 
4
 
5
  # Configure logging
6
  logging.basicConfig(level=logging.INFO)
7
- logger = logging.getLogger("minimal-api")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- app = FastAPI(title="Minimal API Test")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  @app.get("/")
12
  async def root():
@@ -16,23 +108,109 @@ async def root():
16
 
17
  @app.get("/health")
18
  async def health_check():
19
- """Health check endpoint to confirm the app is running"""
 
20
  logger.info("Health check requested")
21
- return {"status": "healthy"}
 
 
 
 
 
 
22
 
23
- @app.get("/ping")
24
- async def ping():
25
- """Simple ping endpoint to test GET requests"""
26
- logger.info("Ping requested")
27
- return {"message": "pong"}
 
28
 
29
- @app.post("/echo")
30
- async def echo(text: str = Form(...)):
31
- """Echo endpoint to test POST requests with form data"""
32
  if not text:
33
  raise HTTPException(status_code=400, detail="No text provided")
34
- logger.info(f"Echo requested with text: {text}")
35
- return {"received_text": text}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
  if __name__ == "__main__":
38
  import uvicorn
 
1
+ import os
2
+ os.environ["HOME"] = "/root"
3
+ os.environ["HF_HOME"] = "/tmp/hf_cache"
4
+
5
  import logging
6
+ import threading
7
+ import tempfile
8
+ import uuid
9
+ import numpy as np
10
+ import soundfile as sf
11
+ from fastapi import FastAPI, HTTPException, UploadFile, File, Form
12
  from fastapi.responses import JSONResponse
13
+ from typing import Dict, Any, Optional
14
 
15
# Configure logging: INFO-level root configuration plus a named logger that
# the rest of this module uses for all of its messages.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("talklas-api")

# The FastAPI application object; the route decorators below register on it.
app = FastAPI(title="Talklas API")
20
+
21
# --- Application state shared between the loader thread and the endpoints ---

# True once load_models_task has finished without error.
models_loaded = False
# True while a load is pending or running.
loading_in_progress = False
# Handle of the background loader thread (None until one is started).
loading_thread = None
# Per-component load state; every component starts out as "not_loaded".
model_status = dict.fromkeys(("stt", "mt", "tts"), "not_loaded")
# Text of the most recent loading error, or None when nothing has failed.
error_message = None

# Speech-to-text processor and model; both are assigned by the background loader.
stt_processor = None
stt_model = None

# Languages accepted by the endpoints, keyed by display name and mapped to
# their three-letter codes.
LANGUAGE_MAPPING = {
    "English": "eng",
    "Tagalog": "tgl",
    "Cebuano": "ceb",
    "Ilocano": "ilo",
    "Waray": "war",
    "Pangasinan": "pag",
}
45
 
46
+ # Function to load models in background
47
def load_models_task():
    """Load the Whisper STT model and record progress in the module globals.

    Intended to run on a background thread. Updates ``model_status``,
    ``models_loaded``, ``error_message``, ``stt_processor`` and ``stt_model``;
    ``loading_in_progress`` is always cleared on exit, whatever the outcome.
    """
    global models_loaded, loading_in_progress, model_status, error_message, stt_processor, stt_model

    checkpoint = "openai/whisper-tiny"
    try:
        loading_in_progress = True

        # The heavy libraries are imported here rather than at module import
        # time, so the app can start serving before they are available.
        logger.info("Starting to load STT model...")
        import torch
        from transformers import WhisperProcessor, WhisperForConditionalGeneration

        try:
            logger.info("Loading Whisper model...")
            model_status["stt"] = "loading"
            stt_processor = WhisperProcessor.from_pretrained(checkpoint)
            stt_model = WhisperForConditionalGeneration.from_pretrained(checkpoint)
            target_device = "cuda" if torch.cuda.is_available() else "cpu"
            stt_model.to(target_device)
            logger.info("STT model loaded successfully")
            model_status["stt"] = "loaded"
        except Exception as stt_exc:
            # An STT failure aborts the task; models_loaded stays False.
            logger.error(f"Failed to load STT model: {str(stt_exc)}")
            model_status["stt"] = "failed"
            error_message = f"STT model loading failed: {str(stt_exc)}"
            return
        else:
            # MT and TTS are deliberately not loaded, to save memory.
            model_status["mt"] = "skipped"
            model_status["tts"] = "skipped"
            logger.info("MT and TTS models skipped to save memory")

            models_loaded = True
            logger.info("Model loading completed successfully")

    except Exception as task_exc:
        error_message = str(task_exc)
        logger.error(f"Error in model loading task: {str(task_exc)}")
    finally:
        loading_in_progress = False
87
+
88
+ # Start loading models in background
89
def start_model_loading():
    """Start the background model loader unless a load is running or already done."""
    global loading_thread, loading_in_progress

    if loading_in_progress or models_loaded:
        return

    # Flag the load as in progress before the thread starts.
    loading_in_progress = True
    loading_thread = threading.Thread(target=load_models_task, daemon=True)
    loading_thread.start()
96
+
97
+ # Start the background process when the app starts
98
@app.on_event("startup")
async def startup_event():
    """Startup hook: log the start-up and begin loading models in the background."""
    logger.info("Application starting up...")
    start_model_loading()
102
 
103
  @app.get("/")
104
  async def root():
 
108
 
109
@app.get("/health")
async def health_check():
    """Report that the app is up, together with the current model-loading state.

    The "status" field is always "healthy"; loading progress and any recorded
    loading error are reported in the remaining fields.
    """
    logger.info("Health check requested")
    report = {"status": "healthy"}
    report["models_loaded"] = models_loaded
    report["loading_in_progress"] = loading_in_progress
    report["model_status"] = model_status
    report["error"] = error_message
    return report
121
 
122
@app.post("/update-languages")
async def update_languages(source_lang: str = Form(...), target_lang: str = Form(...)):
    """Validate a source/target language pair and acknowledge the selection.

    Raises:
        HTTPException: 400 when either language is not a key of LANGUAGE_MAPPING.
    """
    selection_is_valid = source_lang in LANGUAGE_MAPPING and target_lang in LANGUAGE_MAPPING
    if not selection_is_valid:
        raise HTTPException(status_code=400, detail="Invalid language selected")

    logger.info(f"Updating languages: {source_lang} → {target_lang}")
    confirmation = f"Languages updated to {source_lang} → {target_lang}"
    return {"status": confirmation}
128
 
129
@app.post("/translate-text")
async def translate_text(text: str = Form(...), source_lang: str = Form(...), target_lang: str = Form(...)):
    """Accept a text translation request and return a placeholder result.

    No translation is performed: the response echoes the source text and
    reports that translation is not implemented yet.

    Raises:
        HTTPException: 400 when ``text`` is empty, or when either language is
            not a key of LANGUAGE_MAPPING.
    """
    if not text:
        raise HTTPException(status_code=400, detail="No text provided")
    languages_ok = source_lang in LANGUAGE_MAPPING and target_lang in LANGUAGE_MAPPING
    if not languages_ok:
        raise HTTPException(status_code=400, detail="Invalid language selected")

    logger.info(f"Translate-text requested: {text} from {source_lang} to {target_lang}")

    payload = {"request_id": str(uuid.uuid4())}
    payload["status"] = "processing"
    payload["message"] = "Translation not implemented yet (MT model not loaded)."
    payload["source_text"] = text
    payload["translated_text"] = "Translation not available"
    payload["output_audio"] = None
    return payload
147
+
148
def _audio_response(request_id, status, message, source_text="Transcription not available"):
    """Assemble the response payload shared by every /translate-audio outcome."""
    return {
        "request_id": request_id,
        "status": status,
        "message": message,
        "source_text": source_text,
        "translated_text": "Translation not available",
        "output_audio": None,
    }


def _load_mono_16k(path):
    """Read an audio file and return it as a mono float32 waveform at 16 kHz."""
    waveform, sample_rate = sf.read(path, dtype="float32")
    if waveform.ndim > 1:
        # soundfile returns (frames, channels) for multi-channel files; the
        # resampler and the Whisper processor below are given a 1-D signal.
        waveform = waveform.mean(axis=1)
    if sample_rate != 16000:
        logger.info(f"Resampling audio from {sample_rate} Hz to 16000 Hz")
        import librosa
        waveform = librosa.resample(waveform, orig_sr=sample_rate, target_sr=16000)
    return waveform


@app.post("/translate-audio")
async def translate_audio(audio: UploadFile = File(...), source_lang: str = Form(...), target_lang: str = Form(...)):
    """Transcribe an uploaded audio file with the Whisper STT model.

    Translation and speech synthesis are not implemented, so
    ``translated_text`` and ``output_audio`` are always placeholders.

    Args:
        audio: Uploaded audio file.
        source_lang: Source language; must be a key of LANGUAGE_MAPPING.
        target_lang: Target language; must be a key of LANGUAGE_MAPPING.

    Returns:
        A dict with ``request_id``, ``status`` ("processing" while the STT
        model is not loaded, "completed" on success, "failed" on a
        transcription error), ``message``, ``source_text``,
        ``translated_text`` and ``output_audio``.

    Raises:
        HTTPException: 400 when no audio is provided or a language is invalid.
    """
    if not audio:
        raise HTTPException(status_code=400, detail="No audio file provided")
    if source_lang not in LANGUAGE_MAPPING or target_lang not in LANGUAGE_MAPPING:
        raise HTTPException(status_code=400, detail="Invalid language selected")

    logger.info(f"Translate-audio requested: {audio.filename} from {source_lang} to {target_lang}")
    request_id = str(uuid.uuid4())

    # Without a loaded STT model there is nothing to run; answer with a placeholder.
    if model_status["stt"] != "loaded" or stt_processor is None or stt_model is None:
        logger.warning("STT model not loaded, returning placeholder response")
        return _audio_response(
            request_id,
            "processing",
            "STT model not loaded yet. Please try again later.",
        )

    # Persist the upload so soundfile can read it from a path.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as temp_file:
        temp_file.write(await audio.read())
        temp_path = temp_file.name

    try:
        # torch is only imported locally inside load_models_task, so it has to
        # be imported here as well; otherwise the calls below raise NameError.
        import torch

        waveform = _load_mono_16k(temp_path)

        # Send the inputs to whichever device the model actually lives on.
        inputs = stt_processor(waveform, sampling_rate=16000, return_tensors="pt").to(stt_model.device)
        with torch.no_grad():
            generated_ids = stt_model.generate(**inputs)
        transcription = stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

        logger.info(f"Transcription completed: {transcription}")
        return _audio_response(
            request_id,
            "completed",
            "Transcription completed successfully. Translation and TTS not implemented yet.",
            transcription,
        )
    except Exception as e:
        # Failures are reported in the payload rather than as an HTTP error.
        logger.error(f"Error during transcription: {str(e)}")
        return _audio_response(request_id, "failed", f"Transcription failed: {str(e)}")
    finally:
        # Best-effort cleanup: a failed delete must not replace the response.
        try:
            os.unlink(temp_path)
        except OSError as cleanup_error:
            logger.warning(f"Could not remove temporary file {temp_path}: {cleanup_error}")
214
 
215
  if __name__ == "__main__":
216
  import uvicorn