# my-clinical-ai / app.py
# Last commit: 6bf193c "Update app.py" by badal-12 (verified)
import os
import json
import uuid
import numpy as np
from datetime import datetime
from flask import Flask, request, jsonify, send_from_directory
from flask_cors import CORS
from werkzeug.utils import secure_filename
import google.generativeai as genai
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import pipeline
import faiss
import markdown
import re
# Configuration
# The Gemini API key is read from the GEMINI_API_KEY environment variable
# rather than hard-coded: a key committed to source control is readable by
# anyone with access to the repository (the previously committed key should be
# revoked and rotated).
GEMINI_API_KEY = os.environ.get("GEMINI_API_KEY")
if GEMINI_API_KEY:
    genai.configure(api_key=GEMINI_API_KEY)
else:
    # Start up anyway (the RAG index does not need Gemini), but make the
    # missing configuration visible in the logs.
    print("\u26A0\uFE0F GEMINI_API_KEY is not set; genai.configure() was skipped.")
# Root directory for every Hugging Face cache used by this app.
# NOTE(review): `datasets`, `transformers` and `sentence_transformers` are
# imported above this point, and some HF libraries read these variables when
# they are imported. Setting them here may therefore not change those
# libraries' default cache locations. Confirm this, and move the block above
# the third-party imports if the cache location matters.
CACHE_DIR = "/tmp/huggingface"
_hf_cache_env = {
    "HF_HOME": CACHE_DIR,
    "HF_DATASETS_CACHE": f"{CACHE_DIR}/datasets",
    "TRANSFORMERS_CACHE": f"{CACHE_DIR}/transformers",
    "HF_HUB_CACHE": f"{CACHE_DIR}/hub",
}
os.environ.update(_hf_cache_env)
# Make sure the per-library cache folders exist. makedirs also creates their
# shared parent, CACHE_DIR.
for cache_var in ("HF_DATASETS_CACHE", "TRANSFORMERS_CACHE", "HF_HUB_CACHE"):
    os.makedirs(os.environ[cache_var], exist_ok=True)
# Flask application. Frontend files from ../frontend are served at the URL
# root (an empty static_url_path means no "/static" prefix). Flask resolves a
# relative static_folder against the application's root path.
app = Flask(
    __name__,
    static_url_path="",
    static_folder="../frontend",
)
# Enable CORS on all routes using flask_cors defaults.
CORS(app)
# RAG Model Initialization
print("\U0001F680 Initializing RAG System...")
# Load the training split of the medical guidelines dataset.
print("\U0001F4C2 Loading dataset...")
dataset = load_dataset(
    "epfl-llm/guidelines",
    split="train",
    # Reuse the cache path stored in HF_DATASETS_CACHE instead of repeating the
    # "/tmp/huggingface/datasets" literal, so the two cannot drift apart.
    cache_dir=os.environ["HF_DATASETS_CACHE"],
)
# Dataset columns combined (title + text prefix) into the text that is embedded.
TITLE_COL = "title"
CONTENT_COL = "clean_text"
# Initialize models
print("\U0001F916 Loading AI models...")
# Sentence encoder used to turn guideline snippets into vectors for FAISS.
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# Extractive question-answering pipeline backed by a SQuAD-distilled DistilBERT.
# NOTE(review): not referenced elsewhere in the visible file. Confirm it is
# still needed, since loading it adds startup time and memory.
qa_pipeline = pipeline(
    task="question-answering",
    model="distilbert-base-cased-distilled-squad",
)
# Build FAISS index
print("\U0001F50D Building FAISS index...")


def embed_text(batch):
    """Embed one batch of guideline rows for the FAISS index.

    Used as a batched ``Dataset.map`` callback. ``batch`` maps each column name
    to the list of that column's values for the rows in the batch.

    Each row is represented by its title followed by the first 200 characters
    of its text, encoded with the module-level ``embedder``.

    Returns:
        dict: ``{"embeddings": ...}`` holding the encoder output for the batch.
        ``map`` stores it as a new ``embeddings`` column.
    """
    combined_texts = [
        # Treat a missing (None) title or text as empty. This avoids embedding
        # the literal string "None" and crashing on None[:200].
        f"{title or ''} {(content or '')[:200]}"
        for title, content in zip(batch[TITLE_COL], batch[CONTENT_COL])
    ]
    return {"embeddings": embedder.encode(combined_texts, show_progress_bar=False)}


dataset = dataset.map(embed_text, batched=True, batch_size=32)
# Attach a FAISS index over the new column for nearest-neighbour lookups.
dataset.add_faiss_index(column="embeddings")
# Flask Endpoints
@app.route("/")
def serve_index():
    """Return the frontend entry page (index.html) from the static folder."""
    frontend_dir = app.static_folder
    return send_from_directory(frontend_dir, "index.html")
@app.route("/<path:path>")
def serve_static(path):
    """Return the requested file from the frontend static folder.

    ``send_from_directory`` rejects paths that would escape the folder and
    answers 404 for files that do not exist.
    """
    frontend_dir = app.static_folder
    return send_from_directory(frontend_dir, path)
# Run the app only if it's not under Gunicorn (a WSGI server imports this
# module, so __name__ is not "__main__" there).
if __name__ == "__main__":
    # The Werkzeug debugger allows arbitrary code execution from the browser,
    # so it must not be switched on unconditionally for a server bound to all
    # interfaces. With `debug` left unset, Flask.run takes it from FLASK_DEBUG
    # (e.g. FLASK_DEBUG=1 for local development); otherwise it stays off.
    # PORT overrides the default port 5000.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", "5000")))