mac9087 committed
Commit 596a84e · verified · 1 parent: a77cb2e

Update app.py

Files changed (1): app.py (+135 -34)
app.py CHANGED
@@ -24,8 +24,13 @@ models_loaded = False
 loading_thread = None
 load_queue = queue.Queue()
 
-# Use a smaller Whisper model for faster inference
-WHISPER_MODEL_SIZE = "tiny"  # Changed from "small" to "tiny"
+# Define paths with proper permissions
+TEMP_DIR = "/tmp/ai_models"
+os.makedirs(TEMP_DIR, exist_ok=True)
+
+# Environment variable to control model size
+# Set to "tiny" for fastest response, "base" for better quality but still fast
+WHISPER_MODEL_SIZE = os.environ.get("WHISPER_MODEL_SIZE", "tiny")
 
 def load_models():
     """Load all models in background thread"""
@@ -33,13 +38,14 @@ def load_models():
 
     print("Starting model loading...")
 
-    # Load Whisper model with optimized settings
-    whisper_model = WhisperModel(
-        WHISPER_MODEL_SIZE,
-        device="cpu",
-        compute_type="int8",
-        download_root="./models"  # Cache models to disk
-    )
+    try:
+        # Load Whisper model with optimized settings
+        whisper_model = WhisperModel(
+            WHISPER_MODEL_SIZE,
+            device="cpu",
+            compute_type="int8",
+            download_root=TEMP_DIR  # Use temp directory with write permissions
+        )
     print("Whisper model loaded")
 
     # Use a smaller, faster LLM
@@ -62,13 +68,47 @@ def load_models():
     with model_lock:
         models_loaded = True
 
-    print("All models loaded successfully")
+    except Exception as e:
+        print(f"Error loading Whisper model: {str(e)}")
+        whisper_model = None
+
+        # Mark models as loaded even if some failed - we'll use fallbacks
+        with model_lock:
+            models_loaded = True
+
+    print("Model loading completed")
 
     # Process any pending requests that arrived during loading
     while not load_queue.empty():
         callback = load_queue.get()
         callback()
 
+# Fallback methods for when models fail to load
+def fallback_transcribe(audio_path):
+    """Simple fallback when Whisper fails to load"""
+    # Just return empty text - in production you might want a more sophisticated fallback
+    return "I couldn't transcribe the audio due to technical issues."
+
+def fallback_generate_text(user_input):
+    """Simple rule-based response when LLM fails to load"""
+    # Very basic template responses
+    if not user_input or len(user_input) < 5:
+        return "I'm listening. Please continue."
+
+    if "?" in user_input:
+        return "That's an interesting question. I'm processing it now."
+
+    # Simple acknowledgment responses
+    responses = [
+        "I understand what you're saying.",
+        "I'm following your thoughts.",
+        "I hear you loud and clear.",
+        "I'm processing that information.",
+        "That makes sense to me."
+    ]
+    import random
+    return random.choice(responses)
+
 # Start loading models in background thread
 def start_loading_models():
     global loading_thread
@@ -76,7 +116,16 @@ def start_loading_models():
     loading_thread.daemon = True
     loading_thread.start()
 
-start_loading_models()
+# Create temp directory and start loading
+try:
+    os.makedirs(TEMP_DIR, exist_ok=True)
+    print(f"Created model cache directory at {TEMP_DIR}")
+    start_loading_models()
+except Exception as e:
+    print(f"Error setting up model loading: {str(e)}")
+    # Automatically mark as loaded with no models
+    with model_lock:
+        models_loaded = True
 
 def ensure_models_loaded(callback):
     """Ensure models are loaded before processing a request"""
@@ -140,13 +189,18 @@ def generate_ai_response(user_input):
         return "I'm listening. Please say more."
 
     try:
-        start_time = time.time()
-        # Generate response with fewer tokens
-        raw_response = llm(user_input)[0]["generated_text"]
-
-        # Process to get clean, short response
-        final_response = process_response(user_input, raw_response)
-        print(f"LLM processing time: {time.time() - start_time:.2f}s")
+        # If LLM failed to load, use fallback
+        if llm is None:
+            print("Using fallback text generation")
+            final_response = fallback_generate_text(user_input)
+        else:
+            start_time = time.time()
+            # Generate response with fewer tokens
+            raw_response = llm(user_input)[0]["generated_text"]
+
+            # Process to get clean, short response
+            final_response = process_response(user_input, raw_response)
+            print(f"LLM processing time: {time.time() - start_time:.2f}s")
 
         # Cache the response for future identical requests
         response_cache[cache_key] = final_response
@@ -173,19 +227,31 @@ def talk():
     def process_request():
         nonlocal audio_file
         try:
-            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+            # Prepare file paths
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav", dir=TEMP_DIR) as tmp:
                 audio_path = tmp.name
             audio_file.save(audio_path)
 
-            # Transcribe with optimized settings
+            # Transcribe audio
             transcribe_start = time.time()
-            segments, _ = whisper_model.transcribe(
-                audio_path,
-                beam_size=1,  # Reduce beam size for speed
-                vad_filter=True,  # Use voice activity detection to process only speech
-                vad_parameters=dict(min_silence_duration_ms=500)  # Tune VAD for speed
-            )
-            transcription = "".join([seg.text for seg in segments])
+            if whisper_model is None:
+                # Fallback if model failed to load
+                print("Using fallback transcription")
+                transcription = fallback_transcribe(audio_path)
+            else:
+                try:
+                    # Transcribe with optimized settings
+                    segments, _ = whisper_model.transcribe(
+                        audio_path,
+                        beam_size=1,  # Reduce beam size for speed
+                        vad_filter=True,  # Use voice activity detection to process only speech
+                        vad_parameters=dict(min_silence_duration_ms=500)  # Tune VAD for speed
+                    )
+                    transcription = "".join([seg.text for seg in segments])
+                except Exception as e:
+                    print(f"Whisper transcription error: {str(e)}")
+                    transcription = ""
+
             print(f"Transcription time: {time.time() - transcribe_start:.2f}s")
 
             if not transcription.strip():
@@ -196,14 +262,29 @@ def talk():
             # Prepare TTS output path
            tts_audio_path = audio_path.replace(".wav", "_reply.wav")
 
-            # Synthesize speech with optimized settings
+            # Synthesize speech
             tts_start = time.time()
-            tts.tts_to_file(
-                text=final_response,
-                file_path=tts_audio_path,
-                speaker_wav=None,
-                speed=1.1  # Slightly faster speech
-            )
+            if tts is None:
+                # If TTS failed to load, create a simple audio file with message
+                print("Using fallback TTS (no speech synthesis)")
+                # Just copy the input file as a placeholder
+                import shutil
+                shutil.copyfile(audio_path, tts_audio_path)
+            else:
+                try:
+                    # Synthesize speech with optimized settings
+                    tts.tts_to_file(
+                        text=final_response,
+                        file_path=tts_audio_path,
+                        speaker_wav=None,
+                        speed=1.1  # Slightly faster speech
+                    )
+                except Exception as e:
+                    print(f"TTS error: {str(e)}")
+                    # Just copy the input file as a placeholder
+                    import shutil
+                    shutil.copyfile(audio_path, tts_audio_path)
+
             print(f"TTS time: {time.time() - tts_start:.2f}s")
 
             # Return both the audio file and the text response
@@ -278,6 +359,26 @@ def status():
 def index():
     return "Metaverse AI Character API running."
 
+# Add direct-response mode for maximum performance
+@app.route("/quick_chat", methods=["POST"])
+def quick_chat():
+    """Ultra-fast endpoint that skips ML models completely for instant responses"""
+    data = request.get_json()
+    if not data or "text" not in data:
+        return jsonify({"error": "Missing 'text' in request body"}), 400
+
+    try:
+        user_input = data["text"]
+        print(f"Quick chat input: {user_input}")
+
+        # Use simple rule-based responses for maximum speed
+        final_response = fallback_generate_text(user_input)
+
+        return jsonify({"response": final_response})
+    except Exception as e:
+        print(f"Error in quick_chat: {str(e)}")
+        return jsonify({"response": "I'm listening."})
+
 if __name__ == "__main__":
     # Use threaded server for better concurrency
     app.run(host="0.0.0.0", port=7860, threaded=True)
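Below is a minimal client sketch for the /quick_chat route added in this commit. It assumes the Flask app is running locally on port 7860, as in the __main__ block; the base URL, prompt text, and timeout are illustrative assumptions, while the route name, the "text" request field, and the "response" reply field follow the code added above.

# Hypothetical client for the new /quick_chat endpoint (illustrative, not part of the commit).
# Assumes the server was started with `python app.py` and listens on port 7860.
# WHISPER_MODEL_SIZE can optionally be exported (e.g. "base") before starting the server,
# since load_models() reads it from the environment with a "tiny" default; that only
# affects the /talk transcription path, not this endpoint.
import requests

BASE_URL = "http://localhost:7860"  # assumed host and port, matching app.run(...) above

payload = {"text": "Hello, can you hear me?"}  # "text" is the field quick_chat() reads
reply = requests.post(f"{BASE_URL}/quick_chat", json=payload, timeout=10)
reply.raise_for_status()  # the endpoint returns 400 only when "text" is missing
print(reply.json()["response"])  # short rule-based reply from fallback_generate_text()

Because quick_chat() calls fallback_generate_text() directly and never touches Whisper, the LLM, or TTS, it can answer even while the background model-loading thread is still running or after a model has failed to load.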