mac9087 committed on
Commit a77cb2e · verified · 1 Parent(s): d997bfa

Update app.py

Files changed (1): app.py (+203 −156)
app.py CHANGED
@@ -7,206 +7,234 @@ import tempfile
 import os
 import re
 import base64
 
 app = Flask(__name__)
 CORS(app)
 
-# Load models
-whisper_model = WhisperModel("small", device="cpu", compute_type="int8")
 
-# Configure the LLM for short, conversational responses
-llm = pipeline(
-    "text-generation",
-    model="tiiuae/falcon-rw-1b",
-    max_new_tokens=50,  # Reduced token count for shorter responses
-)
 
-tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC", progress_bar=False, gpu=False)
 
 def process_response(input_text, generated_text):
-    # Handle the case where generated_text might be None
     if not generated_text:
         return "I'm not sure what to say about that."
-
     # Make sure both are strings
     input_text = str(input_text).strip()
     generated_text = str(generated_text).strip()
 
-    # Skip empty input
-    if not input_text:
-        clean_response = generated_text
-    # Remove the input text from the beginning of the response
-    elif generated_text.startswith(input_text):
         clean_response = generated_text[len(input_text):].strip()
     else:
         clean_response = generated_text.strip()
 
-    # If we ended up with nothing, provide a default response
     if not clean_response:
         return "I'm listening."
 
-    # Split into sentences and take only the first 1-2 meaningful sentences
-    sentences = re.split(r'(?<=[.!?])\s+', clean_response)
-
-    # Filter out empty or very short sentences
-    meaningful_sentences = [s for s in sentences if len(s) > 5]
-
-    # Take just 1-2 sentences for a casual, human-like response
-    if meaningful_sentences:
-        if len(meaningful_sentences) > 2:
-            result = " ".join(meaningful_sentences[:2])
-        else:
-            result = " ".join(meaningful_sentences)
     else:
-        # If no meaningful sentences, but we have short sentences, use those
-        if sentences and any(s.strip() for s in sentences):
-            short_sentences = [s for s in sentences if s.strip()]
-            result = " ".join(short_sentences[:2])
-        else:
-            # Fallback if no good sentences were found
-            result = "I'm not sure what to say about that."
 
-    # Remove any repetitive phrases
-    result = remove_repetitions(result)
-
-    # Normalize quotes to ASCII equivalents
-    result = normalize_quotes(result)
 
     return result
 
-def normalize_quotes(text):
-    """Replace curly quotes and other problematic Unicode characters with ASCII equivalents"""
-    # Replace curly quotes with straight quotes
-    text = text.replace('\u201c', '"').replace('\u201d', '"')
-    text = text.replace('\u2018', "'").replace('\u2019', "'")
-    # Add more replacements as needed
-    return text
-
-def remove_repetitions(text):
-    # Simple repetition removal
-    words = text.split()
-    if len(words) <= 5:  # Don't process very short responses
-        return text
-
-    result = []
-    for i in range(len(words)):
-        # Check if this word starts a repeated phrase
-        if i < len(words) - 3:  # Need at least 3 words to check for repetition
-            # Check if next 3+ words appear earlier in the text
-            is_repetition = False
-            for j in range(3, min(10, len(words) - i)):  # Check phrases of length 3 to 10
-                phrase = " ".join(words[i:i+j])
-                if phrase in " ".join(result):
-                    is_repetition = True
-                    break
-
-            if not is_repetition:
-                result.append(words[i])
-        else:
-            result.append(words[i])
-
-    return " ".join(result)
-
 def generate_ai_response(user_input):
-    """
-    Centralized function to generate AI responses to ensure consistency
-    between text and voice responses.
-    """
-    # Handle empty or too short input
     if not user_input or len(user_input.strip()) < 2:
         return "I'm listening. Please say more."
 
     try:
-        # Generate response
         raw_response = llm(user_input)[0]["generated_text"]
 
         # Process to get clean, short response
         final_response = process_response(user_input, raw_response)
 
         return final_response
     except Exception as e:
         print(f"Error generating AI response: {str(e)}")
-        # Return a default response if anything goes wrong
         return "I heard you, but I'm having trouble forming a response right now."
 
 @app.route("/talk", methods=["POST"])
 def talk():
     if "audio" not in request.files:
         return jsonify({"error": "No audio file"}), 400
-
-    # Save audio
     audio_file = request.files["audio"]
 
-    try:
-        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
-            audio_path = tmp.name
-            audio_file.save(audio_path)
-
-        # Transcribe
         try:
-            segments, _ = whisper_model.transcribe(audio_path)
-            transcription = "".join([seg.text for seg in segments])
 
-            print(f"Transcription: {transcription}")  # Debugging
 
             if not transcription.strip():
-                # Handle empty transcription
                 final_response = "I didn't catch that. Could you please speak again?"
             else:
-                # Use the centralized function to generate a response
                 final_response = generate_ai_response(transcription)
 
-            print(f"Voice response: {final_response}")  # Debugging
-        except Exception as e:
-            print(f"Transcription error: {str(e)}")
-            final_response = "I had trouble understanding that. Could you try again?"
-
-        # Prepare TTS output path
-        tts_audio_path = audio_path.replace(".wav", "_reply.wav")
-
-        try:
-            # Synthesize speech
-            tts.tts_to_file(text=final_response, file_path=tts_audio_path)
-
-            if not os.path.exists(tts_audio_path) or os.path.getsize(tts_audio_path) == 0:
-                raise Exception("TTS failed to generate audio file")
-
-        except Exception as e:
-            print(f"TTS error: {str(e)}")
-            # If TTS fails, generate a simple audio file with a message
-            # In a production app, you might want to have a pre-recorded fallback audio
-            tts_audio_path = audio_path  # Just reuse the input path for now
-            final_response = "Sorry, I couldn't generate audio right now."
-
-        # Return both the audio file and the text response
-        try:
-            response = send_file(tts_audio_path, mimetype="audio/wav")
 
-            # Base64 encode the response text to avoid Unicode issues in headers
-            encoded_response = base64.b64encode(final_response.encode('utf-8')).decode('ascii')
-            response.headers["X-Response-Text-Base64"] = encoded_response
-            response.headers["Access-Control-Expose-Headers"] = "X-Response-Text-Base64"
 
-            return response
         except Exception as e:
-            print(f"Error sending file: {str(e)}")
-            return jsonify({
-                "error": "Could not send audio response",
-                "text_response": final_response
-            }), 500
-
-    except Exception as e:
-        print(f"Error in talk endpoint: {str(e)}")
-        return jsonify({"error": str(e)}), 500
-    finally:
-        # Clean up temporary files
-        try:
-            if 'audio_path' in locals() and os.path.exists(audio_path):
-                os.unlink(audio_path)
-            if 'tts_audio_path' in locals() and os.path.exists(tts_audio_path) and tts_audio_path != audio_path:
-                os.unlink(tts_audio_path)
-        except Exception as cleanup_error:
-            print(f"Error cleaning up files: {str(cleanup_error)}")
 
 @app.route("/chat", methods=["POST"])
 def chat():
@@ -214,23 +242,42 @@ def chat():
     if not data or "text" not in data:
         return jsonify({"error": "Missing 'text' in request body"}), 400
 
-    try:
-        user_input = data["text"]
-        print(f"Text input: {user_input}")  # Debugging
-
-        # Use the centralized function to generate a response
-        final_response = generate_ai_response(user_input)
-
-        print(f"Text response: {final_response}")  # Debugging
-
-        return jsonify({"response": final_response})
-    except Exception as e:
-        print(f"Error in chat endpoint: {str(e)}")
-        return jsonify({"response": "I'm having trouble processing that. Could you try again?", "error": str(e)})
 
 @app.route("/")
 def index():
     return "Metaverse AI Character API running."
 
 if __name__ == "__main__":
-    app.run(host="0.0.0.0", port=7860)
 import os
 import re
 import base64
+import threading
+import queue
+import time
 
 app = Flask(__name__)
 CORS(app)
 
+# Global variables to hold models and caches
+whisper_model = None
+llm = None
+tts = None
+response_cache = {}
+model_lock = threading.Lock()
+models_loaded = False
+loading_thread = None
+load_queue = queue.Queue()
 
+# Use a smaller Whisper model for faster inference
+WHISPER_MODEL_SIZE = "tiny"  # Changed from "small" to "tiny"
 
+def load_models():
+    """Load all models in background thread"""
+    global whisper_model, llm, tts, models_loaded
+
+    print("Starting model loading...")
+
+    # Load Whisper model with optimized settings
+    whisper_model = WhisperModel(
+        WHISPER_MODEL_SIZE,
+        device="cpu",
+        compute_type="int8",
+        download_root="./models"  # Cache models to disk
+    )
+    print("Whisper model loaded")
+
+    # Use a smaller, faster LLM
+    llm = pipeline(
+        "text-generation",
+        model="distilgpt2",  # Much smaller than falcon-rw-1b
+        max_new_tokens=40,  # Further reduce token count
+        device="cpu"
+    )
+    print("LLM loaded")
+
+    # Load TTS model
+    tts = TTS(
+        model_name="tts_models/en/ljspeech/fast_pitch",  # Using faster model
+        progress_bar=False,
+        gpu=False
+    )
+    print("TTS model loaded")
+
+    with model_lock:
+        models_loaded = True
+
+    print("All models loaded successfully")
+
+    # Drain any callbacks queued while loading; their return values are
+    # discarded, since the queuing clients already received a 503 and retry
+    while not load_queue.empty():
+        callback = load_queue.get()
+        callback()
+
+# Start loading models in background thread
+def start_loading_models():
+    global loading_thread
+    loading_thread = threading.Thread(target=load_models)
+    loading_thread.daemon = True
+    loading_thread.start()
+
+start_loading_models()
+
+def ensure_models_loaded(callback):
+    """Ensure models are loaded before processing a request"""
+    with model_lock:
+        if models_loaded:
+            # Models already loaded: run the handler and return its response
+            return callback()
+        else:
+            # Queue the callback for when models finish loading
+            load_queue.put(callback)
+            return jsonify({
+                "status": "loading",
+                "message": "Models are still loading. Please try again in a moment."
+            }), 503
 
 def process_response(input_text, generated_text):
+    """Process and clean up LLM response - optimized for speed"""
     if not generated_text:
         return "I'm not sure what to say about that."
+
     # Make sure both are strings
     input_text = str(input_text).strip()
     generated_text = str(generated_text).strip()
 
+    # Extract the response portion (everything after the input)
+    if generated_text.startswith(input_text):
         clean_response = generated_text[len(input_text):].strip()
     else:
         clean_response = generated_text.strip()
 
+    # Fallback for empty responses
     if not clean_response:
         return "I'm listening."
 
+    # Simplified sentence extraction - just get first sentence for faster response
+    sentences = re.split(r'(?<=[.!?])\s+', clean_response, maxsplit=2)
+    if sentences:
+        # Just use the first sentence for maximum speed
+        result = sentences[0].strip()
+        # Add second sentence if it's not too long
+        if len(sentences) > 1 and len(sentences[1]) < 30:
+            result += " " + sentences[1].strip()
     else:
+        result = clean_response
 
+    # Simple normalization (replace curly quotes with ASCII equivalents)
+    result = result.replace('\u201c', '"').replace('\u201d', '"')
+    result = result.replace('\u2018', "'").replace('\u2019', "'")
 
     return result
 
 def generate_ai_response(user_input):
+    """Generate AI responses - with caching for speed"""
+    # Check cache for identical requests to avoid recomputation
+    cache_key = user_input.strip().lower()
+    if cache_key in response_cache:
+        print("Cache hit!")
+        return response_cache[cache_key]
+
     if not user_input or len(user_input.strip()) < 2:
         return "I'm listening. Please say more."
 
     try:
+        start_time = time.time()
+        # Generate response with fewer tokens
         raw_response = llm(user_input)[0]["generated_text"]
 
         # Process to get clean, short response
         final_response = process_response(user_input, raw_response)
+        print(f"LLM processing time: {time.time() - start_time:.2f}s")
+
+        # Cache the response for future identical requests
+        response_cache[cache_key] = final_response
+
+        # Limit cache size to prevent memory issues
+        if len(response_cache) > 100:
+            # Remove oldest entries (simple approach)
+            keys_to_remove = list(response_cache.keys())[:-50]
+            for k in keys_to_remove:
+                response_cache.pop(k, None)
 
         return final_response
     except Exception as e:
         print(f"Error generating AI response: {str(e)}")
         return "I heard you, but I'm having trouble forming a response right now."
 
 @app.route("/talk", methods=["POST"])
 def talk():
     if "audio" not in request.files:
         return jsonify({"error": "No audio file"}), 400
+
     audio_file = request.files["audio"]
 
+    def process_request():
+        nonlocal audio_file
         try:
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+                audio_path = tmp.name
+                audio_file.save(audio_path)
 
+            # Transcribe with optimized settings
+            transcribe_start = time.time()
+            segments, _ = whisper_model.transcribe(
+                audio_path,
+                beam_size=1,  # Reduce beam size for speed
+                vad_filter=True,  # Use voice activity detection to process only speech
+                vad_parameters=dict(min_silence_duration_ms=500)  # Tune VAD for speed
+            )
+            transcription = "".join([seg.text for seg in segments])
+            print(f"Transcription time: {time.time() - transcribe_start:.2f}s")
 
             if not transcription.strip():
                 final_response = "I didn't catch that. Could you please speak again?"
             else:
                 final_response = generate_ai_response(transcription)
 
+            # Prepare TTS output path
+            tts_audio_path = audio_path.replace(".wav", "_reply.wav")
 
+            # Synthesize speech with optimized settings
+            tts_start = time.time()
+            tts.tts_to_file(
+                text=final_response,
+                file_path=tts_audio_path,
+                speaker_wav=None,
+                speed=1.1  # Slightly faster speech
+            )
+            print(f"TTS time: {time.time() - tts_start:.2f}s")
 
+            # Return both the audio file and the text response
+            try:
+                response = send_file(tts_audio_path, mimetype="audio/wav")
+                encoded_response = base64.b64encode(final_response.encode('utf-8')).decode('ascii')
+                response.headers["X-Response-Text-Base64"] = encoded_response
+                response.headers["Access-Control-Expose-Headers"] = "X-Response-Text-Base64"
+                return response
+            except Exception as e:
+                print(f"Error sending file: {str(e)}")
+                return jsonify({
+                    "error": "Could not send audio response",
+                    "text_response": final_response
+                }), 500
+
         except Exception as e:
+            print(f"Error in talk endpoint: {str(e)}")
+            return jsonify({"error": str(e)}), 500
+        finally:
+            # Clean up temporary files
+            try:
+                if 'audio_path' in locals() and os.path.exists(audio_path):
+                    os.unlink(audio_path)
+                if 'tts_audio_path' in locals() and os.path.exists(tts_audio_path) and tts_audio_path != audio_path:
+                    os.unlink(tts_audio_path)
+            except Exception as cleanup_error:
+                print(f"Error cleaning up files: {str(cleanup_error)}")
+
+    # Ensure models are loaded before processing
+    return ensure_models_loaded(process_request)
 
 @app.route("/chat", methods=["POST"])
 def chat():
     if not data or "text" not in data:
         return jsonify({"error": "Missing 'text' in request body"}), 400
 
+    user_input = data["text"]
+
+    def process_request():
+        try:
+            print(f"Text input: {user_input}")  # Debugging
+
+            # Start timing
+            start_time = time.time()
+
+            # Generate response
+            final_response = generate_ai_response(user_input)
+
+            # Report timing
+            print(f"Total processing time: {time.time() - start_time:.2f}s")
+
+            return jsonify({"response": final_response})
+        except Exception as e:
+            print(f"Error in chat endpoint: {str(e)}")
+            return jsonify({"response": "I'm having trouble processing that. Could you try again?", "error": str(e)})
+
+    # Ensure models are loaded before processing
+    return ensure_models_loaded(process_request)
+
+@app.route("/status", methods=["GET"])
+def status():
+    """Check if models are loaded and ready"""
+    with model_lock:
+        if models_loaded:
+            return jsonify({"status": "ready", "message": "All models loaded and ready"})
+        else:
+            return jsonify({"status": "loading", "message": "Models are still loading"})
 
 @app.route("/")
 def index():
     return "Metaverse AI Character API running."
 
 if __name__ == "__main__":
+    # Use threaded server for better concurrency
+    app.run(host="0.0.0.0", port=7860, threaded=True)
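
For reference, a minimal client-side sketch of how the endpoints in this revision can be exercised. The base URL and the input file name are assumptions for illustration; the /status polling, the /chat JSON contract, and the base64-encoded X-Response-Text-Base64 header on /talk responses come from the code above.

# client_example.py - hypothetical client for the API above
import base64
import time

import requests

BASE_URL = "http://localhost:7860"  # assumed local deployment; adjust as needed

# Poll /status until the background loader thread reports the models as ready
while requests.get(f"{BASE_URL}/status").json().get("status") != "ready":
    time.sleep(2)

# Text chat: POST JSON with a "text" field and read back the "response" field
r = requests.post(f"{BASE_URL}/chat", json={"text": "Hello there!"})
print(r.json()["response"])

# Voice chat: POST a WAV file under the "audio" field; the reply audio is the
# response body, and the reply text is base64-encoded in a response header
with open("question.wav", "rb") as f:  # hypothetical input recording
    r = requests.post(f"{BASE_URL}/talk", files={"audio": f})
if r.ok:
    reply_text = base64.b64decode(r.headers["X-Response-Text-Base64"]).decode("utf-8")
    print("Character said:", reply_text)
    with open("reply.wav", "wb") as out:
        out.write(r.content)
else:
    print("Request failed:", r.status_code, r.text)

Until load_models() finishes, both POST endpoints answer 503 with a "loading" status, which is why the client polls /status (or simply retries) before sending real traffic.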