# Spaces:
# Sleeping
# Sleeping
import os | |
import json | |
import uuid | |
import numpy as np | |
from datetime import datetime | |
from flask import Flask, request, jsonify, send_from_directory | |
from flask_cors import CORS | |
from werkzeug.utils import secure_filename | |
import google.generativeai as genai | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer | |
from transformers import pipeline | |
import faiss | |
import markdown | |
import re | |
# Configuration
# The Gemini API key is read from the environment rather than committed to source
# (GEMINI_API_KEY, falling back to GOOGLE_API_KEY).
# NOTE(review): a key was previously hard-coded here; treat it as leaked and revoke/rotate it.
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY") or os.environ.get("GOOGLE_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    print("WARNING: GEMINI_API_KEY is not set; Gemini requests will fail until it is configured.")
# Root every Hugging Face cache under a single directory.
# NOTE(review): the HF libraries are imported above this point and presumably resolve
# their cache paths at import time, so these variables may not change their defaults;
# pass cache directories explicitly where it matters.
CACHE_DIR = "/tmp/huggingface"
_HF_CACHE_SUBDIRS = {
    "HF_DATASETS_CACHE": "datasets",
    "TRANSFORMERS_CACHE": "transformers",
    "HF_HUB_CACHE": "hub",
}
os.environ["HF_HOME"] = CACHE_DIR
for _env_name, _subdir in _HF_CACHE_SUBDIRS.items():
    os.environ[_env_name] = f"{CACHE_DIR}/{_subdir}"
# Create each per-library cache directory (HF_HOME is created as their parent).
for _env_name in _HF_CACHE_SUBDIRS:
    os.makedirs(os.environ[_env_name], exist_ok=True)
# Flask application: static assets are served from ../frontend at the URL root.
app = Flask(import_name=__name__, static_url_path="", static_folder="../frontend")
# Enable CORS with flask_cors defaults.
# NOTE(review): the defaults presumably allow any origin — confirm that is intended.
CORS(app=app)
# RAG Model Initialization
print("\U0001F680 Initializing RAG System...")
# Load the medical guidelines dataset into the configured datasets cache directory.
print("\U0001F4C2 Loading dataset...")
dataset = load_dataset(
    "epfl-llm/guidelines",
    split="train",
    cache_dir=os.environ.get("HF_DATASETS_CACHE", "/tmp/huggingface/datasets"),
)
# Columns combined into the text that gets embedded.
TITLE_COL = "title"
CONTENT_COL = "clean_text"
# Initialize models
print("\U0001F916 Loading AI models...")
# The HF_* environment variables are assigned after `transformers` and
# `sentence_transformers` were imported. NOTE(review): those libraries typically
# resolve cache paths at import time, so the hub cache directory is passed explicitly
# instead of relying on the environment.
_MODEL_CACHE_DIR = os.environ.get("HF_HUB_CACHE", "/tmp/huggingface/hub")
embedder = SentenceTransformer("all-MiniLM-L6-v2", cache_folder=_MODEL_CACHE_DIR)
qa_pipeline = pipeline(
    "question-answering",
    model="distilbert-base-cased-distilled-squad",
    model_kwargs={"cache_dir": _MODEL_CACHE_DIR},
)
# Build FAISS index | |
print("\U0001F50D Building FAISS index...") | |
def embed_text(batch): | |
combined_texts = [ | |
f"{title} {content[:200]}" | |
for title, content in zip(batch[TITLE_COL], batch[CONTENT_COL]) | |
] | |
return {"embeddings": embedder.encode(combined_texts, show_progress_bar=False)} | |
# Add an "embeddings" column by running embed_text over 32-row batches, then
# attach a FAISS index over that column for nearest-neighbour lookups.
dataset = dataset.map(function=embed_text, batch_size=32, batched=True)
dataset.add_faiss_index("embeddings")
# Flask Endpoints
@app.route("/")
def serve_index():
    """Serve the frontend entry page, ``index.html``, from the static folder."""
    return send_from_directory(app.static_folder, "index.html")


@app.route("/<path:path>")
def serve_static(path):
    """Serve the file at *path* from the static folder.

    ``send_from_directory`` refuses paths that would escape the folder.
    NOTE(review): with ``static_url_path=""`` Flask's own static route
    presumably also matches ``/<path:...>``; confirm which view actually
    handles these requests.
    """
    return send_from_directory(app.static_folder, path)
# Run the app only if it's not under Gunicorn | |
if __name__ == "__main__": | |
app.run(host="0.0.0.0", port=5000, debug=True) | |