Update app.py
app.py
CHANGED
@@ -7,206 +7,234 @@ import tempfile
 import os
 import re
 import base64

 app = Flask(__name__)
 CORS(app)

-whisper_model =

-llm = pipeline(
-    "text-generation",
-    model="tiiuae/falcon-rw-1b",
-    max_new_tokens=50,  # Reduced token count for shorter responses
-)

 def process_response(input_text, generated_text):
     if not generated_text:
         return "I'm not sure what to say about that."
     # Make sure both are strings
     input_text = str(input_text).strip()
     generated_text = str(generated_text).strip()

-    if
-        clean_response = generated_text
-    # Remove the input text from the beginning of the response
-    elif generated_text.startswith(input_text):
         clean_response = generated_text[len(input_text):].strip()
     else:
         clean_response = generated_text.strip()

     if not clean_response:
         return "I'm listening."

-    sentences = re.split(r'(?<=[.!?])\s+', clean_response)
-        if len(meaningful_sentences) > 2:
-            result = " ".join(meaningful_sentences[:2])
-        else:
-            result = " ".join(meaningful_sentences)
     else:
-        if sentences and any(s.strip() for s in sentences):
-            short_sentences = [s for s in sentences if s.strip()]
-            result = " ".join(short_sentences[:2])
-        else:
-            # Fallback if no good sentences were found
-            result = "I'm not sure what to say about that."

-    result = remove_repetitions(result)
-    # Normalize quotes to ASCII equivalents
-    result = normalize_quotes(result)

     return result

-def normalize_quotes(text):
-    """Replace curly quotes and other problematic Unicode characters with ASCII equivalents"""
-    # Replace curly quotes with straight quotes
-    text = text.replace('“', '"').replace('”', '"')
-    text = text.replace('‘', "'").replace('’', "'")
-    # Add more replacements as needed
-    return text
-
-def remove_repetitions(text):
-    # Simple repetition removal
-    words = text.split()
-    if len(words) <= 5:  # Don't process very short responses
-        return text
-
-    result = []
-    for i in range(len(words)):
-        # Check if this word starts a repeated phrase
-        if i < len(words) - 3:  # Need at least 3 words to check for repetition
-            # Check if the next 3+ words appear earlier in the text
-            is_repetition = False
-            for j in range(3, min(10, len(words) - i)):  # Check phrases of length 3 to 10
-                phrase = " ".join(words[i:i+j])
-                if phrase in " ".join(result):
-                    is_repetition = True
-                    break
-
-            if not is_repetition:
-                result.append(words[i])
-        else:
-            result.append(words[i])
-
-    return " ".join(result)

 def generate_ai_response(user_input):
     if not user_input or len(user_input.strip()) < 2:
         return "I'm listening. Please say more."

     try:
         raw_response = llm(user_input)[0]["generated_text"]

         # Process to get clean, short response
         final_response = process_response(user_input, raw_response)

         return final_response
     except Exception as e:
         print(f"Error generating AI response: {str(e)}")
-        # Return a default response if anything goes wrong
         return "I heard you, but I'm having trouble forming a response right now."

 @app.route("/talk", methods=["POST"])
 def talk():
     if "audio" not in request.files:
         return jsonify({"error": "No audio file"}), 400
-    # Save audio
     audio_file = request.files["audio"]

-    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
-        audio_path = tmp.name
-        audio_file.save(audio_path)

-    # Transcribe
     try:

         if not transcription.strip():
-            # Handle empty transcription
             final_response = "I didn't catch that. Could you please speak again?"
         else:
-            # Use the centralized function to generate a response
             final_response = generate_ai_response(transcription)

-    except Exception as e:
-        print(f"Transcription error: {str(e)}")
-        final_response = "I had trouble understanding that. Could you try again?"

-    # Prepare TTS output path
-    tts_audio_path = audio_path.replace(".wav", "_reply.wav")

-    try:
-        # Synthesize speech
-        tts.tts_to_file(text=final_response, file_path=tts_audio_path)

-        if not os.path.exists(tts_audio_path) or os.path.getsize(tts_audio_path) == 0:
-            raise Exception("TTS failed to generate audio file")

-    except Exception as e:
-        print(f"TTS error: {str(e)}")
-        # If TTS fails, generate a simple audio file with a message
-        # In a production app, you might want to have a pre-recorded fallback audio
-        tts_audio_path = audio_path  # Just reuse the input path for now
-        final_response = "Sorry, I couldn't generate audio right now."

-    # Return both the audio file and the text response
-    try:
-        response = send_file(tts_audio_path, mimetype="audio/wav")

-        return response
     except Exception as e:
-        print(f"Error
-        return jsonify({
-        if 'tts_audio_path' in locals() and os.path.exists(tts_audio_path) and tts_audio_path != audio_path:
-            os.unlink(tts_audio_path)
-    except Exception as cleanup_error:
-        print(f"Error cleaning up files: {str(cleanup_error)}")

 @app.route("/chat", methods=["POST"])
 def chat():
@@ -214,23 +242,42 @@ def chat():
     if not data or "text" not in data:
         return jsonify({"error": "Missing 'text' in request body"}), 400

 @app.route("/")
 def index():
     return "Metaverse AI Character API running."

 if __name__ == "__main__":
-    app.run(host="0.0.0.0", port=7860)
@@ -7,206 +7,234 @@ import tempfile
 import os
 import re
 import base64
+import threading
+import queue
+import time

 app = Flask(__name__)
 CORS(app)

+# Global variables to hold models and caches
+whisper_model = None
+llm = None
+tts = None
+response_cache = {}
+model_lock = threading.Lock()
+models_loaded = False
+loading_thread = None
+load_queue = queue.Queue()

+# Use a smaller Whisper model for faster inference
+WHISPER_MODEL_SIZE = "tiny"  # Changed from "small" to "tiny"

+def load_models():
+    """Load all models in a background thread"""
+    global whisper_model, llm, tts, models_loaded
+
+    print("Starting model loading...")
+
+    # Load the Whisper model with optimized settings
+    whisper_model = WhisperModel(
+        WHISPER_MODEL_SIZE,
+        device="cpu",
+        compute_type="int8",
+        download_root="./models"  # Cache models to disk
+    )
+    print("Whisper model loaded")
+
+    # Use a smaller, faster LLM
+    llm = pipeline(
+        "text-generation",
+        model="distilgpt2",  # Much smaller than falcon-rw-1b
+        max_new_tokens=40,  # Further reduce token count
+        device="cpu"
+    )
+    print("LLM loaded")
+
+    # Load the TTS model
+    tts = TTS(
+        model_name="tts_models/en/ljspeech/fast_pitch",  # Using a faster model
+        progress_bar=False,
+        gpu=False
+    )
+    print("TTS model loaded")
+
+    with model_lock:
+        models_loaded = True
+
+    print("All models loaded successfully")
+
+    # Process any pending requests that arrived during loading
+    while not load_queue.empty():
+        callback = load_queue.get()
+        callback()

+# Start loading models in a background thread
+def start_loading_models():
+    global loading_thread
+    loading_thread = threading.Thread(target=load_models)
+    loading_thread.daemon = True
+    loading_thread.start()

+start_loading_models()

+def ensure_models_loaded(callback):
+    """Ensure models are loaded before processing a request"""
+    with model_lock:
+        if models_loaded:
+            # Models already loaded, process immediately and return the handler's response
+            return callback()
+        else:
+            # Queue the callback for when models finish loading
+            load_queue.put(callback)
+            return jsonify({
+                "status": "loading",
+                "message": "Models are still loading. Please try again in a moment."
+            }), 503

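Note that a queued callback runs in the loader thread after the 503 has already been sent, so its return value never reaches the original client; callers are expected to retry. A minimal client-side sketch of that retry loop (a hypothetical helper, assuming the server is reachable at localhost:7860 and the requests library is installed):

    import time
    import requests

    def chat_when_ready(text, url="http://localhost:7860/chat", attempts=10):
        # Retry while the server reports that models are still loading (HTTP 503)
        for _ in range(attempts):
            r = requests.post(url, json={"text": text})
            if r.status_code != 503:
                r.raise_for_status()
                return r.json()["response"]
            time.sleep(2)  # back off while models load
        raise TimeoutError("models did not finish loading in time")
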
 def process_response(input_text, generated_text):
+    """Process and clean up the LLM response - optimized for speed"""
     if not generated_text:
         return "I'm not sure what to say about that."
+
     # Make sure both are strings
     input_text = str(input_text).strip()
     generated_text = str(generated_text).strip()

+    # Extract the response portion (everything after the input)
+    if generated_text.startswith(input_text):
         clean_response = generated_text[len(input_text):].strip()
     else:
         clean_response = generated_text.strip()

+    # Fallback for empty responses
     if not clean_response:
         return "I'm listening."

+    # Simplified sentence extraction - just take the first sentence for a faster response
+    sentences = re.split(r'(?<=[.!?])\s+', clean_response, maxsplit=2)
+    if sentences:
+        # Use the first sentence for maximum speed
+        result = sentences[0].strip()
+        # Add the second sentence if it is not too long
+        if len(sentences) > 1 and len(sentences[1]) < 30:
+            result += " " + sentences[1].strip()
     else:
+        result = clean_response

+    # Simple quote normalization
+    result = result.replace('“', '"').replace('”', '"')
+    result = result.replace('‘', "'").replace('’', "'")

     return result

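As a quick illustration of the trimming logic (hypothetical values): when the LLM echoes the prompt, process_response strips it and keeps at most the first two sentences, appending the second only if it is under 30 characters:

    # Hypothetical example of the trimming behavior
    raw = "Tell me a joke. Here is one. It is about a very long subject that rambles on."
    print(process_response("Tell me a joke.", raw))
    # -> "Here is one." (the second sentence is ~49 characters, so it is dropped)
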
 def generate_ai_response(user_input):
+    """Generate AI responses - with caching for speed"""
     if not user_input or len(user_input.strip()) < 2:
         return "I'm listening. Please say more."

+    # Check the cache for identical requests to avoid recomputation
+    cache_key = user_input.strip().lower()
+    if cache_key in response_cache:
+        print("Cache hit!")
+        return response_cache[cache_key]
+
     try:
+        start_time = time.time()
+        # Generate a response with fewer tokens
         raw_response = llm(user_input)[0]["generated_text"]

         # Process to get clean, short response
         final_response = process_response(user_input, raw_response)
+        print(f"LLM processing time: {time.time() - start_time:.2f}s")
+
+        # Cache the response for future identical requests
+        response_cache[cache_key] = final_response
+
+        # Limit the cache size to prevent memory issues
+        if len(response_cache) > 100:
+            # Remove the oldest entries (simple approach)
+            keys_to_remove = list(response_cache.keys())[:-50]
+            for k in keys_to_remove:
+                response_cache.pop(k, None)

         return final_response
     except Exception as e:
         print(f"Error generating AI response: {str(e)}")
         return "I heard you, but I'm having trouble forming a response right now."

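The eviction above relies on dict insertion order (guaranteed in Python 3.7+), but each trim rebuilds the full key list. An OrderedDict-based LRU is a common alternative; a small sketch of that design choice, as an illustration rather than part of this commit:

    from collections import OrderedDict

    response_cache = OrderedDict()
    MAX_CACHE = 100

    def cache_put(key, value):
        # Refreshed keys move to the end; the oldest entries beyond the cap are evicted
        response_cache[key] = value
        response_cache.move_to_end(key)
        while len(response_cache) > MAX_CACHE:
            response_cache.popitem(last=False)
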
 @app.route("/talk", methods=["POST"])
 def talk():
     if "audio" not in request.files:
         return jsonify({"error": "No audio file"}), 400
+
     audio_file = request.files["audio"]

+    def process_request():
+        nonlocal audio_file
         try:
+            with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
+                audio_path = tmp.name
+                audio_file.save(audio_path)

+            # Transcribe with optimized settings
+            transcribe_start = time.time()
+            segments, _ = whisper_model.transcribe(
+                audio_path,
+                beam_size=1,  # Reduce beam size for speed
+                vad_filter=True,  # Use voice activity detection to process only speech
+                vad_parameters=dict(min_silence_duration_ms=500)  # Tune VAD for speed
+            )
+            transcription = "".join([seg.text for seg in segments])
+            print(f"Transcription time: {time.time() - transcribe_start:.2f}s")

             if not transcription.strip():
                 final_response = "I didn't catch that. Could you please speak again?"
             else:
                 final_response = generate_ai_response(transcription)

+            # Prepare TTS output path
+            tts_audio_path = audio_path.replace(".wav", "_reply.wav")

+            # Synthesize speech with optimized settings
+            tts_start = time.time()
+            tts.tts_to_file(
+                text=final_response,
+                file_path=tts_audio_path,
+                speaker_wav=None,
+                speed=1.1  # Slightly faster speech
+            )
+            print(f"TTS time: {time.time() - tts_start:.2f}s")

+            # Return both the audio file and the text response
+            try:
+                response = send_file(tts_audio_path, mimetype="audio/wav")
+                encoded_response = base64.b64encode(final_response.encode('utf-8')).decode('ascii')
+                response.headers["X-Response-Text-Base64"] = encoded_response
+                response.headers["Access-Control-Expose-Headers"] = "X-Response-Text-Base64"
+                return response
+            except Exception as e:
+                print(f"Error sending file: {str(e)}")
+                return jsonify({
+                    "error": "Could not send audio response",
+                    "text_response": final_response
+                }), 500
+
         except Exception as e:
+            print(f"Error in talk endpoint: {str(e)}")
+            return jsonify({"error": str(e)}), 500
+        finally:
+            # Clean up temporary files
+            try:
+                if 'audio_path' in locals() and os.path.exists(audio_path):
+                    os.unlink(audio_path)
+                if 'tts_audio_path' in locals() and os.path.exists(tts_audio_path) and tts_audio_path != audio_path:
+                    os.unlink(tts_audio_path)
+            except Exception as cleanup_error:
+                print(f"Error cleaning up files: {str(cleanup_error)}")
+
+    # Ensure models are loaded before processing
+    return ensure_models_loaded(process_request)

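The reply text travels in the X-Response-Text-Base64 header (base64-encoded so non-ASCII text survives HTTP header rules), and Access-Control-Expose-Headers makes it readable from browsers. A client-side sketch of consuming both the audio body and the header (hypothetical, again assuming the requests library and localhost:7860):

    import base64
    import requests

    with open("question.wav", "rb") as f:
        r = requests.post("http://localhost:7860/talk", files={"audio": f})

    with open("reply.wav", "wb") as out:
        out.write(r.content)  # the synthesized speech

    header = r.headers.get("X-Response-Text-Base64", "")
    reply_text = base64.b64decode(header).decode("utf-8") if header else ""
    print(reply_text)
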
 @app.route("/chat", methods=["POST"])
 def chat():

@@ -214,23 +242,42 @@ def chat():
     if not data or "text" not in data:
         return jsonify({"error": "Missing 'text' in request body"}), 400

+    user_input = data["text"]
+
+    def process_request():
+        try:
+            print(f"Text input: {user_input}")  # Debugging
+
+            # Start timing
+            start_time = time.time()
+
+            # Generate response
+            final_response = generate_ai_response(user_input)
+
+            # Report timing
+            print(f"Total processing time: {time.time() - start_time:.2f}s")
+
+            return jsonify({"response": final_response})
+        except Exception as e:
+            print(f"Error in chat endpoint: {str(e)}")
+            return jsonify({"response": "I'm having trouble processing that. Could you try again?", "error": str(e)})
+
+    # Ensure models are loaded before processing
+    return ensure_models_loaded(process_request)
+
+@app.route("/status", methods=["GET"])
+def status():
+    """Check if models are loaded and ready"""
+    with model_lock:
+        if models_loaded:
+            return jsonify({"status": "ready", "message": "All models loaded and ready"})
+        else:
+            return jsonify({"status": "loading", "message": "Models are still loading"})

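The /status endpoint lets clients poll readiness before sending real traffic, for example (a hypothetical client snippet, assuming the server runs on port 7860):

    import time
    import requests

    # Poll until the models report ready
    while requests.get("http://localhost:7860/status").json()["status"] != "ready":
        time.sleep(2)
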
 @app.route("/")
 def index():
     return "Metaverse AI Character API running."

 if __name__ == "__main__":
+    # Use a threaded server for better concurrency
+    app.run(host="0.0.0.0", port=7860, threaded=True)
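Flask's built-in server, even with threaded=True, is meant for local testing; a production deployment would more typically sit behind a WSGI server. One possible invocation, assuming gunicorn is installed and this file is named app.py: gunicorn --workers 1 --threads 4 --bind 0.0.0.0:7860 app:app. A single worker keeps only one copy of the models in memory, while threads provide the concurrency.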
|