# Provenance (copied from its Hugging Face file page): uploaded by shayan5422
# in commit ec1f977 ("Upload 5 files"), 12.6 kB.
import os
os.environ['TOKENIZERS_PARALLELISM'] = 'false' # Disable tokenizer parallelism
from flask import Flask, request, jsonify
from flask_cors import CORS
import numpy as np
import json
import traceback
import logging # Added for background task logging
import threading # Added for background task
import time # Added for background task
import schedule # Added for background task
# --- Logging ---
# The root logger's default level is WARNING, so without this every
# logging.info(...) call in this module (scheduler start-up, update progress)
# would be silently discarded. basicConfig does nothing if the root logger
# already has handlers.
# NOTE(review): once this has run, later basicConfig() calls (e.g. inside
# daily_update) become no-ops.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)

# --- Import the daily update function ---
try:
    from daily_update import main as run_daily_update
except ImportError:
    logging.error("Failed to import daily_update.py. The daily update task will not run.")
    # None is the sentinel the scheduler code checks before doing any work.
    run_daily_update = None
# ---
app = Flask(__name__)  # created before anything else so it always exists

# Root of the persistent volume; the other scripts must use the same path.
PERSISTENT_STORAGE_PATH = "/data"  # <-- ADJUST IF YOUR PATH IS DIFFERENT

# Browser origins allowed to call this API: local dev servers plus the Vercel frontend.
CORS(
    app,
    origins=[
        "http://127.0.0.1:3000",
        "http://localhost:3000",
        "https://rag-huggingface.vercel.app",
    ],
    supports_credentials=True,
)

# --- Configuration: artifacts read from persistent storage ---
INDEX_FILE = os.path.join(PERSISTENT_STORAGE_PATH, "index.faiss")
MAP_FILE = os.path.join(PERSISTENT_STORAGE_PATH, "index_to_metadata.pkl")
MODEL_DATA_DIR = os.path.join(PERSISTENT_STORAGE_PATH, "model_data_json")  # per-model JSON files
EMBEDDING_MODEL = 'all-mpnet-base-v2'

# --- Populated by load_resources(); None/False until a load succeeds ---
faiss = pickle = None
index = index_to_metadata = model = None
SentenceTransformer = None  # the imported class, kept for reference
RESOURCES_LOADED = False
def load_resources():
    """Load the embedding model, FAISS index and metadata map into module globals.

    Returns immediately (after printing a note) once RESOURCES_LOADED is True.
    Never raises: any failure is printed and RESOURCES_LOADED is left False, so a
    later call (for example after the index files have been generated) can retry.

    Globals written: faiss, pickle, SentenceTransformer, model, index,
    index_to_metadata and, last of all on success, RESOURCES_LOADED.
    """
    global faiss, pickle, index, index_to_metadata, model, SentenceTransformer, RESOURCES_LOADED
    if RESOURCES_LOADED:  # Prevent re-loading
        print("Resources already loaded.")
        return

    print("Loading resources...")
    try:
        index_path = INDEX_FILE
        map_path = MAP_FILE
        # Check the cheap preconditions before the expensive model load, so a
        # missing file fails fast instead of after loading the transformer.
        if not os.path.exists(index_path):
            raise FileNotFoundError(f"FAISS index file not found at {index_path}")
        if not os.path.exists(map_path):
            raise FileNotFoundError(f"Metadata map file not found at {map_path}")

        # Deferred imports: heavy modules are only pulled in when actually loading.
        print("Importing Faiss and Pickle...")
        import faiss as faiss_local
        import pickle as pickle_local
        faiss = faiss_local
        pickle = pickle_local
        print("Faiss and Pickle imported successfully.")

        if model is None:
            print(f"Importing SentenceTransformer and loading model: {EMBEDDING_MODEL}")
            from sentence_transformers import SentenceTransformer as SentenceTransformer_local
            SentenceTransformer = SentenceTransformer_local
            # Published immediately so a retry after a later failure can reuse it.
            model = SentenceTransformer_local(EMBEDDING_MODEL)
            print("Sentence transformer model loaded successfully.")
        else:
            print("Sentence transformer model already loaded; reusing it.")

        print(f"Loading FAISS index from: {index_path}")
        index_local = faiss_local.read_index(index_path)
        print("FAISS index loaded successfully.")

        print(f"Loading index-to-Metadata map from: {map_path}")
        # NOTE(security): pickle.load can execute arbitrary code embedded in the
        # file; MAP_FILE must only ever be written by trusted processes.
        with open(map_path, 'rb') as f:
            map_local = pickle_local.load(f)
        print("Index-to-Metadata map loaded successfully.")

        # Publish index and map together and flip the flag last: /search checks
        # RESOURCES_LOADED before it touches the other globals.
        index = index_local
        index_to_metadata = map_local
        print("All resources loaded successfully.")
        RESOURCES_LOADED = True
    except FileNotFoundError as fnf_error:
        print(f"Error: {fnf_error}")
        print(f"Please ensure {os.path.basename(INDEX_FILE)} and {os.path.basename(MAP_FILE)} exist in the persistent storage directory ({PERSISTENT_STORAGE_PATH}).")
        print("You might need to run the update process first or manually place initial files there.")
        RESOURCES_LOADED = False
    except ImportError as import_error:
        print(f"Import Error loading resources: {import_error}")
        traceback.print_exc()
        RESOURCES_LOADED = False
    except Exception as e:
        print(f"Unexpected error loading resources: {e}")
        traceback.print_exc()  # full traceback for unexpected loading errors
        RESOURCES_LOADED = False
# Load the model, index and metadata map at import time (e.g. when Gunicorn
# imports 'app:app'); load_resources() itself ignores repeated calls once loaded.
load_resources()

# --- Background update task settings ---
# Daily run time for the update job, 24-hour "HH:MM" (server local time).
UPDATE_TIME = "02:00"
# Hours between runs for an interval-based schedule; the active schedule runs
# once a day at UPDATE_TIME.
UPDATE_INTERVAL_HOURS = 24
def run_update_task():
    """Run daily_update.main() once, logging (never raising) any failure.

    Skips the run when daily_update.py could not be imported or when the
    DEEPSEEK_API_KEY environment variable is unset. After a successful run, if
    the search resources are still not loaded (e.g. the index files did not
    exist at startup), load_resources() is retried so /search can start serving
    without a process restart.
    """
    if run_daily_update is None:
        logging.warning("run_daily_update function not available (import failed). Skipping task.")
        return

    logging.info(f"Background task: Starting daily update check (scheduled for {UPDATE_TIME})...")
    if not os.getenv("DEEPSEEK_API_KEY"):
        logging.error("Background task: DEEPSEEK_API_KEY not set. Daily update cannot run.")
        return

    try:
        run_daily_update()  # main() from daily_update.py
    except Exception as e:
        # logging.exception attaches the traceback to the same record.
        logging.exception(f"Background task: Error during daily update execution: {e}")
        return
    logging.info("Background task: Daily update process finished.")

    if not RESOURCES_LOADED:
        logging.info("Background task: Search resources not loaded; retrying load_resources() after update.")
        load_resources()
    # NOTE(review): when resources are already loaded, load_resources() returns
    # early, so an index rebuilt by daily_update (if it rewrites INDEX_FILE /
    # MAP_FILE) is only picked up after a restart — confirm whether a hot
    # reload is wanted.
def background_scheduler():
    """Scheduler thread body: run the update once now, then daily at UPDATE_TIME.

    Logs and returns immediately if daily_update.py could not be imported;
    otherwise polls the `schedule` registry every 60 seconds and never returns.
    """
    logging.info(f"Background scheduler started. Will run update task daily around {UPDATE_TIME}.")
    if run_daily_update is None:
        logging.error("Background scheduler: daily_update.py could not be imported. Scheduler will not run tasks.")
        return

    # Interval-based alternative: schedule.every(UPDATE_INTERVAL_HOURS).hours.do(run_update_task)
    daily_job = schedule.every().day.at(UPDATE_TIME)
    daily_job.do(run_update_task)
    logging.info(f"Scheduled daily update task for {UPDATE_TIME}.")

    # Do one update right away rather than waiting for the first UPDATE_TIME.
    logging.info("Background task: Running initial update check on startup...")
    run_update_task()
    logging.info("Background task: Initial update check finished.")

    poll_seconds = 60  # how often to check whether a scheduled job is due
    while True:
        schedule.run_pending()
        time.sleep(poll_seconds)
# --- Start the background scheduler thread ---
# Skip it only in the Werkzeug reloader's parent process during Flask development
# mode; everywhere else (e.g. production under Gunicorn) start it.
# Caveat: every process that imports this module starts its own scheduler, so
# multiple WSGI workers mean multiple schedulers. A file lock or an external
# process manager would be needed to guarantee a single one.
# NOTE(review): FLASK_ENV is deprecated in recent Flask releases; when it is
# unset, the skip condition below is never true, so both reloader processes
# start a scheduler in development.
if os.environ.get("FLASK_ENV") == "development" and os.environ.get("WERKZEUG_RUN_MAIN") != "true":
    logging.info("Skipping background scheduler start in Werkzeug reloader process.")
elif run_daily_update is None:
    logging.warning("Background scheduler thread NOT started because daily_update.py failed to import.")
else:
    scheduler_thread = threading.Thread(target=background_scheduler, daemon=True)
    scheduler_thread.start()
    logging.info("Background scheduler thread started.")
# --- End Background Update Task ---
def _read_model_description(model_id):
    """Return the 'description' field from MODEL_DATA_DIR/<model_id with '/'→'_'>.json.

    Returns None when model_id is not a non-empty string, MODEL_DATA_DIR is
    falsy, the file does not exist, or it cannot be read/parsed (the read error
    is printed).
    """
    if not isinstance(model_id, str) or not model_id or not MODEL_DATA_DIR:
        return None
    filepath = os.path.join(MODEL_DATA_DIR, model_id.replace('/', '_') + '.json')
    if not os.path.exists(filepath):
        return None
    try:
        with open(filepath, 'r', encoding='utf-8') as f:
            model_data = json.load(f)
        return model_data.get('description')
    except Exception as e:
        # Best effort: a bad description file must not fail the whole search.
        print(f"Error reading description file {filepath}: {e}")
        return None


@app.route('/search', methods=['POST'])
def search():
    """Embed the JSON ``query`` and return its nearest entries from the FAISS index.

    Request body (JSON object):
        query (str, required): text to search for.
        top_k (int, optional): maximum number of results, default 10; must be >= 1.

    Responses:
        200 ``{"results": [...]}``: each result is a copy of the stored metadata
            dict plus ``distance`` (float of the FAISS distance) and
            ``description`` (from MODEL_DATA_DIR, or 'No description available.').
        400 ``{"error": ...}``: body is not a JSON object, or query/top_k is invalid.
        500 ``{"error": ...}``: resources not loaded, or an unexpected failure.
    """
    if not RESOURCES_LOADED:
        # Fail rather than reload here: a failed initial load means something is
        # wrong with the environment/files.
        print("Error: Search request received, but resources are not loaded.")
        return jsonify({"error": "Backend resources not initialized. Check server logs."}), 500
    if model is None or index is None or index_to_metadata is None or faiss is None:
        print("Error: Search request received, but some core components (model, index, map, faiss) are None.")
        return jsonify({"error": "Backend components inconsistency. Check server logs."}), 500

    # silent=True: a missing or invalid JSON body yields None (-> JSON 400 below)
    # instead of Flask's default HTML error response.
    data = request.get_json(silent=True)
    if not isinstance(data, dict) or 'query' not in data:
        return jsonify({"error": "Missing 'query' in request body"}), 400
    query = data['query']
    if not isinstance(query, str):
        return jsonify({"error": "'query' must be a string"}), 400
    top_k = data.get('top_k', 10)  # Default to top 10
    # bool is a subclass of int, so reject it explicitly.
    if isinstance(top_k, bool) or not isinstance(top_k, int) or top_k < 1:
        return jsonify({"error": "'top_k' must be a positive integer"}), 400

    try:
        query_embedding = model.encode([query], convert_to_numpy=True).astype('float32')

        # Never request more neighbours than the index holds: the surplus slots
        # would only come back as -1 placeholders (and still be allocated).
        k = min(top_k, index.ntotal)
        if k == 0:
            return jsonify({"results": []})
        distances, indices = index.search(query_embedding, k)

        results = []
        for idx, dist in zip(indices[0], distances[0]):
            # -1 means "no neighbour found"; also skip ids missing from the map.
            if idx < 0 or idx not in index_to_metadata:
                print(f"Warning: Index {idx} out of bounds or not found in metadata mapping.")
                continue
            entry = index_to_metadata[idx].copy()  # copy: never mutate the shared map
            entry['distance'] = float(dist)
            entry['description'] = _read_model_description(entry.get('model_id')) or 'No description available.'
            results.append(entry)

        return jsonify({"results": results})
    except Exception as e:
        print(f"Error during search: {e}")
        traceback.print_exc()  # full traceback for search errors
        return jsonify({"error": "An error occurred during search."}), 500
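# There is intentionally no `if __name__ == '__main__':` entry point: the app is
# served by a WSGI server (e.g. Gunicorn importing 'app:app') using the module-level `app`.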