from flask import Flask, request, jsonify, send_file
from flask_cors import CORS
from faster_whisper import WhisperModel
from transformers import pipeline
from TTS.api import TTS
import tempfile
import os
import re
import base64
import hashlib
import threading
import functools
import time
from cachetools import LRUCache, cached, TTLCache
import gc
import psutil
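
# Runtime dependencies implied by the imports above (install names only; exact
# versions are an assumption and should match your deployment):
#   flask, flask-cors, faster-whisper, transformers, TTS (Coqui), cachetools, psutil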

app = Flask(__name__)
CORS(app)

# Response cache settings (despite the MODEL_ prefix, these configure the
# text-response cache below).
MODEL_CACHE_SIZE = 200
MODEL_CACHE_TTL = 7200  # seconds
USE_GPU = False

# Models are loaded lazily on first use; the lock guards concurrent loading.
whisper_model = None
llm = None
tts = None
models_loaded = False
models_lock = threading.Lock()

response_cache = TTLCache(maxsize=MODEL_CACHE_SIZE, ttl=MODEL_CACHE_TTL)
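
# With maxsize=200 and ttl=7200 seconds (two hours), a repeated question inside
# that window is answered from this cache instead of re-running the language model.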


def load_models():
    """Load models optimized for low CPU environments"""
    global whisper_model, llm, tts, models_loaded

    # Fast path: skip the lock once models are loaded.
    if models_loaded:
        return

    with models_lock:
        # Double-checked locking: another thread may have finished loading
        # while we were waiting for the lock.
        if models_loaded:
            return

        print("Loading models for low-resource environment...")
        start_time = time.time()

        gc.collect()

        # CPU-only inference with int8 quantization to keep memory low.
        device = "cpu"
        compute_type = "int8"

        def log_memory():
            process = psutil.Process(os.getpid())
            memory_info = process.memory_info()
            memory_mb = memory_info.rss / 1024 / 1024
            print(f"Memory usage: {memory_mb:.2f} MB")

        print("Loading whisper model...")
        log_memory()
        whisper_model = WhisperModel("tiny", device=device, compute_type=compute_type)

        print("Loading language model...")
        log_memory()
        llm = pipeline(
            "text-generation",
            model="tiiuae/falcon-rw-1b",
            max_new_tokens=30,
            device=-1,  # -1 forces CPU
        )

        print("Loading TTS model...")
        log_memory()
        tts = TTS(model_name="tts_models/en/ljspeech/tacotron2-DDC",
                  progress_bar=False,
                  gpu=False)

        gc.collect()

        models_loaded = True
        log_memory()
        print(f"Models loaded in {time.time() - start_time:.2f} seconds")


@cached(cache=response_cache)
def generate_ai_response(user_input):
    """
    Generate AI responses with caching to avoid repetitive processing.
    Optimized for low CPU environments.
    """
    load_models()

    if not user_input or len(user_input.strip()) < 2:
        return "I'm listening. Please say more."

    normalized_input = user_input.lower().strip()

    # Fuzzy lookup against previously cached inputs. cachetools stores each
    # entry under a hash-key tuple built from the call arguments, so the
    # original input string is the first element of the key.
    for cache_key in list(response_cache.keys()):
        cached_input = cache_key[0] if isinstance(cache_key, tuple) else cache_key
        if not isinstance(cached_input, str):
            continue
        if cached_input and normalized_input and (
                cached_input.lower() in normalized_input or
                normalized_input in cached_input.lower() or
                levenshtein_distance(normalized_input, cached_input.lower()) < 5):
            cached_value = response_cache.get(cache_key)
            if cached_value is not None:
                print(f"Using cached similar response for: {cached_input}")
                return cached_value

    try:
        start_time = time.time()
        timeout = 3.0

        raw_response = llm(user_input, max_new_tokens=30)[0]["generated_text"]

        # Note: this check runs after generation finishes, so it does not
        # interrupt a slow call; it only falls back to a stock reply.
        elapsed = time.time() - start_time
        if elapsed > timeout:
            print(f"Response generation taking too long: {elapsed:.2f}s")
            return "Let me think about that for a moment."

        final_response = process_response(user_input, raw_response)

        gc.collect()

        return final_response
    except Exception as e:
        print(f"Error generating AI response: {str(e)}")
        return "I heard you, but I'm having trouble forming a response right now."


def levenshtein_distance(s1, s2):
    """
    Standard iterative Levenshtein edit distance using a two-row DP table,
    used to decide whether a new input is close enough to a cached one.
    """
    if len(s1) < len(s2):
        return levenshtein_distance(s2, s1)

    if not s2:
        return len(s1)

    previous_row = range(len(s2) + 1)
    for i, c1 in enumerate(s1):
        current_row = [i + 1]
        for j, c2 in enumerate(s2):
            insertions = previous_row[j + 1] + 1
            deletions = current_row[j] + 1
            substitutions = previous_row[j] + (c1 != c2)
            current_row.append(min(insertions, deletions, substitutions))
        previous_row = current_row

    return previous_row[-1]
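
# Example: levenshtein_distance("hello there", "hello tere") == 1, so a cached
# reply is reused for inputs with an edit distance below 5 (see generate_ai_response).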


def process_response(input_text, generated_text):
    """Trim and clean up the raw LLM output into a short spoken-style reply."""
    if not generated_text:
        return "I'm not sure what to say about that."

    input_text = str(input_text).strip()
    generated_text = str(generated_text).strip()

    # Text-generation pipelines echo the prompt, so strip it from the front.
    if not input_text:
        clean_response = generated_text
    elif generated_text.startswith(input_text):
        clean_response = generated_text[len(input_text):].strip()
    else:
        clean_response = generated_text.strip()

    if not clean_response:
        return "I'm listening."

    # Split into sentences and keep at most two meaningful ones.
    sentences = re.split(r'(?<=[.!?])\s+', clean_response)
    meaningful_sentences = [s for s in sentences if len(s) > 5]

    if meaningful_sentences:
        if len(meaningful_sentences) > 2:
            result = " ".join(meaningful_sentences[:2])
        else:
            result = " ".join(meaningful_sentences)
    else:
        # Fall back to whatever short sentences are available.
        short_sentences = [s for s in sentences if s.strip()]
        if short_sentences:
            result = " ".join(short_sentences[:2])
        else:
            result = "I'm not sure what to say about that."

    result = remove_repetitions(result)
    result = normalize_quotes(result)

    return result
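
# Example: process_response("Hi", "Hi Hello there! How are you? I am fine.")
# strips the echoed prompt and keeps two sentences: "Hello there! How are you?"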


def normalize_quotes(text):
    """Replace curly quotes with straight ASCII quotes."""
    replacements = {
        "\u201c": '"', "\u201d": '"',   # left/right double curly quotes
        "\u2018": "'", "\u2019": "'",   # left/right single curly quotes
    }
    for old, new in replacements.items():
        text = text.replace(old, new)
    return text
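
# Example: normalize_quotes("\u201chello\u201d") returns '"hello"'.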


def remove_repetitions(text):
    """Drop words that begin a 3+ word phrase already seen earlier in the text."""
    words = text.split()
    if len(words) <= 5:
        return text

    result = []
    text_so_far = ""

    for i in range(len(words)):
        is_repetition = False

        # Only look for repeats while at least a 3-word phrase can start here.
        if i < len(words) - 3:
            for j in range(3, min(10, len(words) - i)):
                phrase = " ".join(words[i:i + j])
                if phrase in text_so_far:
                    is_repetition = True
                    break

        if not is_repetition:
            result.append(words[i])
            text_so_far += words[i] + " "

    return " ".join(result)
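

# Example client call for /talk (illustrative; host and filename are placeholders):
#
#   curl -X POST http://localhost:7860/talk \
#        -F "audio=@question.wav" \
#        -D headers.txt -o reply.wav
#
# reply.wav holds the spoken answer; the reply text arrives base64-encoded in the
# X-Response-Text-Base64 response header.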


@app.route("/talk", methods=["POST"])
def talk():
    """Voice endpoint: transcribe uploaded audio, generate a reply, return it as speech."""
    if "audio" not in request.files:
        return jsonify({"error": "No audio file"}), 400

    process = psutil.Process(os.getpid())
    memory_before = process.memory_info().rss / 1024 / 1024
    print(f"Memory before processing: {memory_before:.2f} MB")

    load_models()

    start_time = time.time()
    audio_file = request.files["audio"]

    # Initialize paths so the cleanup code in `finally` can check them safely.
    audio_path = None
    tts_audio_path = None
    cached_audio_path = None

    try:
        # Persist the upload to a temp file for faster-whisper.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
            audio_path = tmp.name
            audio_file.save(audio_path)

        try:
            segments, _ = whisper_model.transcribe(
                audio_path,
                beam_size=1,
                vad_filter=True,
                language="en"
            )
            transcription = "".join([seg.text for seg in segments])

            print(f"Transcription: {transcription}")
            print(f"Transcription time: {time.time() - start_time:.2f}s")

            if not transcription.strip():
                final_response = "I didn't catch that. Could you please speak again?"
            else:
                final_response = generate_ai_response(transcription)

            print(f"Voice response: {final_response}")
            print(f"Response generation time: {time.time() - start_time:.2f}s")

            # Stable content hash so cached audio survives restarts
            # (Python's built-in hash() is randomized per process).
            response_hash = hashlib.md5(final_response.encode("utf-8")).hexdigest()
            cached_audio_path = os.path.join(tempfile.gettempdir(), f"cached_response_{response_hash}.wav")

            if os.path.exists(cached_audio_path):
                print("Using cached audio response")
                tts_audio_path = cached_audio_path
            else:
                tts_audio_path = audio_path.replace(".wav", "_reply.wav")

                try:
                    tts.tts_to_file(
                        text=final_response,
                        file_path=tts_audio_path,
                        speed=1.1
                    )

                    if not os.path.exists(tts_audio_path) or os.path.getsize(tts_audio_path) == 0:
                        raise Exception("TTS failed to generate audio file")

                    # Cache short replies, which are the most likely to repeat.
                    if len(final_response) < 100:
                        try:
                            import shutil
                            shutil.copy(tts_audio_path, cached_audio_path)
                        except Exception as cache_error:
                            print(f"Error caching audio: {str(cache_error)}")

                except Exception as e:
                    print(f"TTS error: {str(e)}")
                    tts_audio_path = audio_path
                    final_response = "Sorry, I couldn't generate audio right now."
        except Exception as e:
            print(f"Transcription error: {str(e)}")
            final_response = "I had trouble understanding that. Could you try again?"
            tts_audio_path = audio_path

        try:
            response = send_file(tts_audio_path, mimetype="audio/wav")

            # Ship the reply text alongside the audio, base64-encoded so it is
            # safe to carry in an HTTP header.
            encoded_response = base64.b64encode(final_response.encode('utf-8')).decode('ascii')
            response.headers["X-Response-Text-Base64"] = encoded_response
            response.headers["Access-Control-Expose-Headers"] = "X-Response-Text-Base64"

            print(f"Total processing time: {time.time() - start_time:.2f}s")
            memory_after = process.memory_info().rss / 1024 / 1024
            print(f"Memory after processing: {memory_after:.2f} MB")

            gc.collect()

            return response
        except Exception as e:
            print(f"Error sending file: {str(e)}")
            return jsonify({
                "error": "Could not send audio response",
                "text_response": final_response
            }), 500

    except Exception as e:
        print(f"Error in talk endpoint: {str(e)}")
        return jsonify({"error": str(e)}), 500
    finally:
        # Remove the uploaded audio and any one-off TTS output, but keep files
        # that belong to the persistent response cache.
        try:
            if audio_path and os.path.exists(audio_path):
                os.unlink(audio_path)
            if (tts_audio_path and tts_audio_path != cached_audio_path
                    and tts_audio_path != audio_path and os.path.exists(tts_audio_path)):
                os.unlink(tts_audio_path)
        except Exception as cleanup_error:
            print(f"Error cleaning up files: {str(cleanup_error)}")

        gc.collect()
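

# Example client call for /chat (illustrative; host and text are placeholders):
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"text": "Hello, who are you?"}'
#
# Returns JSON of the form {"response": "..."}.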
@app.route("/chat", methods=["POST"]) |
|
def chat(): |
|
data = request.get_json() |
|
if not data or "text" not in data: |
|
return jsonify({"error": "Missing 'text' in request body"}), 400 |
|
|
|
|
|
load_models() |
|
|
|
try: |
|
user_input = data["text"] |
|
print(f"Text input: {user_input}") |
|
|
|
|
|
final_response = generate_ai_response(user_input) |
|
|
|
print(f"Text response: {final_response}") |
|
|
|
return jsonify({"response": final_response}) |
|
except Exception as e: |
|
print(f"Error in chat endpoint: {str(e)}") |
|
return jsonify({"response": "I'm having trouble processing that. Could you try again?", "error": str(e)}) |
|
|
|
@app.route("/") |
|
def index(): |
|
return "Metaverse AI Character API running." |
|
|
|
|
|


# Maps response text -> path of a pre-generated wav file (filled in below).
tts_audio_cache = {}


def precache_common_responses():
    """Pre-generate audio for common responses to save processing time"""
    common_responses = [
        "I didn't catch that. Could you please speak again?",
        "I'm listening. Please say more.",
        "I heard you, but I'm having trouble forming a response right now.",
        "I'm not sure what to say about that.",
        "Let me think about that for a moment."
    ]

    global tts
    if tts is None:
        load_models()

    print("Pre-caching common audio responses...")
    for response in common_responses:
        try:
            # Same stable content hash as in talk(), so the endpoint finds these files.
            response_hash = hashlib.md5(response.encode("utf-8")).hexdigest()
            cached_path = os.path.join(tempfile.gettempdir(), f"cached_response_{response_hash}.wav")

            if not os.path.exists(cached_path):
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
                    tmp_path = tmp.name

                tts.tts_to_file(text=response, file_path=tmp_path)
                os.rename(tmp_path, cached_path)

            tts_audio_cache[response] = cached_path
            print(f"Cached: {response}")
        except Exception as e:
            print(f"Failed to cache response '{response}': {str(e)}")

    print("Finished pre-caching")


@app.route("/health", methods=["GET"])
def health_check():
    """Health check endpoint to verify API is running"""
    memory_usage = psutil.Process(os.getpid()).memory_info().rss / 1024 / 1024

    return jsonify({
        "status": "ok",
        "models_loaded": models_loaded,
        "memory_usage_mb": round(memory_usage, 2),
        "cache_size": len(response_cache),
        "uptime_seconds": time.time() - startup_time
    })
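
# Example (illustrative host): `curl http://localhost:7860/health` returns JSON
# with status, models_loaded, memory_usage_mb, cache_size and uptime_seconds.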

# Recorded at import time; /health uses this to report uptime.
startup_time = time.time()


if __name__ == "__main__":
    print("Starting Metaverse AI Character API (optimized for real-time on 2 vCPUs)...")

    # Warm the models in the background so startup doesn't block on loading.
    model_thread = threading.Thread(target=load_models)
    model_thread.daemon = True
    model_thread.start()

    # Pre-generate audio for common fallback replies in the background.
    cache_thread = threading.Thread(target=precache_common_responses)
    cache_thread.daemon = True
    cache_thread.start()

    app.run(
        host="0.0.0.0",
        port=7860,
        threaded=True,
        debug=False,
        use_reloader=False
    )