# -*- coding: utf-8 -*-
"""app.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1BmTzCgYHoIX81jKTqf4ImJaKRRbxgoTS
"""
import os
import csv
import pandas as pd
import plotly.express as px
from datetime import datetime
import torch
import faiss
import numpy as np
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
# from google.colab import drive
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from peft import PeftModel
from huggingface_hub import login
from transformers import pipeline as hf_pipeline
from fpdf import FPDF
import uuid
import textwrap
from dotenv import load_dotenv
import shutil
try:
    import whisper
except ImportError:
    os.system("pip install -U openai-whisper")
    import whisper
# Load Whisper model here
whisper_model = whisper.load_model("base")
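# Note: "base" trades accuracy for speed; openai-whisper also ships
# tiny/small/medium/large checkpoints if more resources are available.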
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("⚠️ HF_TOKEN not set; gated models (e.g. Gemma) may fail to load.")
# Mount Google Drive
#drive.mount('/content/drive')
# -------------------------------
# 🔧 Configuration
# -------------------------------
base_model_path = "google/gemma-2-9b-it"
#peft_model_path = "Jaamie/gemma-mental-health-qlora"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_model_bge = "BAAI/bge-base-en-v1.5"
#save_path_bge = "./models/bge-base-en-v1.5"
faiss_index_path = "./qa_faiss_embedding.index"
chunked_text_path = "./chunked_text_RAG_text.txt"
READER_MODEL_NAME = "google/gemma-2-9b-it"
#READER_MODEL_NAME = "google/gemma-2b-it"
log_file_path = "./diagnosis_logs.csv"
feedback_file_path = "./feedback_logs.csv"
# -------------------------------
# 🔧 Logging setup
# -------------------------------
if not os.path.exists(log_file_path):
    with open(log_file_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["timestamp", "user_id", "input_type", "query", "diagnosis", "confidence_score", "status"])
# -------------------------------
# 🔧 Feedback setup
# -------------------------------
if not os.path.exists(feedback_file_path):
    with open(feedback_file_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow([
            "feedback_id", "timestamp", "user_id", "input_type", "query",
            "diagnosis", "status", "feedback"
        ])
# Ensure directory exists
#os.makedirs(save_path_bge, exist_ok=True)
# -------------------------------
# 🔧 Model setup
# -------------------------------
# Load Sentence Transformer Model
# if not os.path.exists(os.path.join(save_path_bge, "config.json")):
# print("Saving model to Google Drive...")
# embedding_model = SentenceTransformer(embedding_model_bge)
# embedding_model.save(save_path_bge)
# print("Model saved successfully!")
# else:
# print("Loading model from Google Drive...")
# device = 'cuda' if torch.cuda.is_available() else 'cpu'
# embedding_model = SentenceTransformer(save_path_bge, device=device)
embedding_model = SentenceTransformer(embedding_model_bge, device=device)
print("βœ… BGE Embedding model loaded from Hugging Face.")
# Load FAISS Index
faiss_index = faiss.read_index(faiss_index_path)
print("FAISS index loaded successfully!")
# Load chunked text
def load_chunked_text():
    with open(chunked_text_path, "r", encoding="utf-8") as f:
        return f.read().split("\n\n---\n\n")
chunked_text = load_chunked_text()
print(f"Loaded {len(chunked_text)} text chunks.")
# Load the emotion classifier
emotion_classifier = hf_pipeline("text-classification", model="nateraw/bert-base-uncased-emotion")
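# This classifier appears to be fine-tuned on the 6-class "emotion"
# dataset, returning labels such as sadness/joy/love/anger/fear/surprise
# together with a confidence score.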
# -------------------------------
# 🧠 Load base model + LoRA adapter
# -------------------------------
# base_model = AutoModelForCausalLM.from_pretrained(
# base_model_path,
# torch_dtype=torch.float16,
# device_map="auto" # Use accelerate for smart placement
# )
# # Load the LoRA adapter on top of the base model
# diagnosis_model = PeftModel.from_pretrained(
# base_model,
# peft_model_path
# ).to(device)
# # Load tokenizer from the same fine-tuned repo
# diagnosis_tokenizer = AutoTokenizer.from_pretrained(peft_model_path)
# # Set model to evaluation mode
# diagnosis_model.eval()
# print("βœ… Model & tokenizer loaded successfully.")
# # Create text-generation pipeline WITHOUT `device` arg
# READER_LLM = pipeline(
# model=diagnosis_model,
# tokenizer=diagnosis_tokenizer,
# task="text-generation",
# do_sample=True,
# temperature=0.2,
# repetition_penalty=1.1,
# return_full_text=False,
# max_new_tokens=500
# )
tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)
#model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME).to(device)
# model_id = "mistralai/Mistral-7B-Instruct-v0.1"
# #model_id = "TheBloke/Gemma-2-7B-IT-GGUF"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(
# model_id,
# torch_dtype=torch.float16,
# device_map="auto",
# ).to(device)
READER_LLM = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    do_sample=True,
    temperature=0.2,
    repetition_penalty=1.1,
    return_full_text=False,
    max_new_tokens=500,
    # device=device,
)
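# The `device` argument is deliberately left out (see the note in the
# commented block above): the pipeline inherits the device of the model
# object it is given, and combining `device=` with accelerate's
# device_map can raise a placement conflict.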
# -------------------------------
# 🔧 Whisper Model Setup
# -------------------------------
def process_whisper_query(audio):
    try:
        audio_data = whisper.load_audio(audio)
        audio_data = whisper.pad_or_trim(audio_data)
        mel = whisper.log_mel_spectrogram(audio_data).to(whisper_model.device)
        result = whisper_model.decode(mel, whisper.DecodingOptions(fp16=False))
        transcribed_text = result.text.strip()
        response, download_path = process_query(transcribed_text, input_type="voice")
        return response, download_path
    except Exception as e:
        return f"⚠️ Error processing audio: {str(e)}", None
def extract_diagnosis(response_text: str) -> str:
    for line in response_text.splitlines():
        if "Diagnosed Mental Disorder" in line:
            return line.split(":")[-1].strip(" *")
    return "Unknown"
# calculating the correctness of the answer - Hallucination
def calculate_rag_confidence(query_embedding, top_k_docs_embeddings, generation_logprobs=None):
    """
    Combines retriever and generation signals to compute a confidence score.

    Args:
        query_embedding (np.ndarray): Embedding vector of the user query (shape: [1, dim]).
        top_k_docs_embeddings (np.ndarray): Embedding matrix of top-k retrieved documents (shape: [k, dim]).
        generation_logprobs (list, optional): List of logprobs for generated tokens.

    Returns:
        float: Final confidence score (0 to 1).
    """
    retriever_similarities = cosine_similarity(query_embedding, top_k_docs_embeddings)
    retriever_confidence = float(np.max(retriever_similarities))
    if generation_logprobs:
        gen_confidence = float(np.exp(np.mean(generation_logprobs)))
    else:
        gen_confidence = 0.0  # fallback if unavailable
    alpha, beta = 0.6, 0.4
    final_confidence = alpha * retriever_confidence + beta * gen_confidence
    return round(final_confidence, 4)
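# Worked example (illustrative): with a best retriever cosine similarity
# of 0.82 and no generation logprobs (the current call path never passes
# them), the score is
#   0.6 * 0.82 + 0.4 * 0.0 = 0.492
# so without logprobs the score is effectively capped at alpha = 0.6.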
# Main Process
def process_query(user_query, input_type="text"):
    # Embed the query
    query_embedding = embedding_model.encode(user_query, normalize_embeddings=True)
    query_embedding = np.array([query_embedding], dtype=np.float32)
    # Search FAISS index
    k = 5  # Retrieve top 5 relevant docs
    distances, indices = faiss_index.search(query_embedding, k)
    retrieved_docs = [chunked_text[i] for i in indices[0]]
    # Construct context
    context = "\nExtracted documents:\n" + "".join([f"Document {i}:::\n{doc}\n" for i, doc in enumerate(retrieved_docs)])
    # Detect emotion
    emotion_result = emotion_classifier(user_query)[0]
    print(f"Detected emotion: {emotion_result}")
    emotion = emotion_result['label']
    value = round(emotion_result['score'], 2)
    # Define RAG prompt
    prompt_in_chat_format = [
        {"role": "user", "content": f"""
You are an AI assistant specialized in diagnosing mental disorders in humans.
Using the information contained in the context, answer the question comprehensively.
The **Diagnosed Mental Disorder** must be exactly one from this list:
[Normal, Depression, Suicidal, Anxiety, Stress, Bi-Polar, Personality Disorder]
Your response must include:
1. **Diagnosed Mental Disorder**
2. **Detected emotion** {emotion}
3. **Intensity of emotion** {value}
4. **Matching Symptoms** from the context
5. **Personalized Treatment**
6. **Helpline Numbers**
7. **Source Link** (if applicable)
Make sure to provide a comprehensive and accurate diagnosis and explain the personalized treatment in detail.
If a disorder cannot be determined, return **Diagnosed Mental Disorder** as "Unknown".
---
Context:
{context}
Question: {user_query}"""},
        {"role": "assistant", "content": ""},
    ]
    RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
        prompt_in_chat_format, tokenize=False, add_generation_prompt=True
    )
    # Generate response
    # answer = READER_LLM(RAG_PROMPT_TEMPLATE)[0]["generated_text"]
    try:
        response = READER_LLM(RAG_PROMPT_TEMPLATE)
        # print("🔍 Raw LLM output:", response)
        answer = response[0]["generated_text"] if response and "generated_text" in response[0] else "⚠️ No output generated."
    except Exception as e:
        print("❌ Error during generation:", e)
        answer = "⚠️ An error occurred while generating the response."
    # Get embeddings of retrieved docs
    retrieved_doc_embeddings = embedding_model.encode(retrieved_docs, normalize_embeddings=True)
    retrieved_doc_embeddings = np.array(retrieved_doc_embeddings, dtype=np.float32)
    # Calculate RAG-based confidence
    confidence_score = calculate_rag_confidence(query_embedding, retrieved_doc_embeddings)
    # Add to response
    answer += f"\n\n🧭 Accuracy & Closeness of the Answer: {confidence_score:.2f}"
    answer += "\n\n*Derived from semantic similarity and generation certainty."
    # Extracting diagnosis
    diagnosis = extract_diagnosis(answer)
    status = "fallback" if diagnosis.lower() == "unknown" else "success"
    # Log interaction
    log_query(input_type=input_type, query=user_query, diagnosis=diagnosis, confidence_score=confidence_score, status=status)
    download_path = create_summary_txt(answer)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    user_id = session_data["latest"]["user_id"]
    # Prepend to the answer string
    answer_header = f"🧾 Session ID: {user_id}\n📅 Timestamp: {timestamp}\n\n"
    return answer_header + answer, download_path
    # return answer, download_path
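# Illustrative text-only invocation (the FAISS index and chunk file must
# exist first):
# answer, report_path = process_query("I feel hopeless and can't sleep", input_type="text")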
# Dashboard Interface
def diagnosis_dashboard():
    try:
        df = pd.read_csv(log_file_path)
        if df.empty:
            return "No data logged yet."
        # Filter out unknown or fallback cases if needed
        df = df[df["diagnosis"].notna()]
        df = df[df["diagnosis"].str.lower() != "unknown"]
        # Diagnosis frequency
        diagnosis_counts = df["diagnosis"].value_counts().reset_index()
        diagnosis_counts.columns = ["Diagnosis", "Count"]
        # Create bar chart
        fig = px.bar(
            diagnosis_counts,
            x="Diagnosis",
            y="Count",
            color="Diagnosis",
            title="📊 Mental Health Diagnosis Distribution",
            text_auto=True
        )
        fig.update_layout(showlegend=False)
        return fig
    except Exception as e:
        return f"⚠️ Error loading dashboard: {str(e)}"
# For logs functionality
# def log_query(input_type, query, diagnosis, confidence_score, status):
# with open(log_file_path, "a", newline="", encoding="utf-8") as f:
# writer = csv.writer(f, quoting=csv.QUOTE_ALL)
# writer.writerow([
# datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
# input_type.replace('"', '""'),
# query.replace('"', '""'),
# diagnosis.replace('"', '""'),
# str(confidence_score),
# status
# ])
session_data = {}

def log_query(input_type, query, diagnosis, confidence_score, status):
    user_id = f"SSuser_ID_{uuid.uuid4().hex[:8]}"
    # Store in-memory session data for feedback use
    session_data["latest"] = {
        "user_id": user_id,
        "input_type": input_type,
        "query": query,
        "diagnosis": diagnosis,
        "confidence_score": confidence_score,
        "status": status
    }
    # csv.writer with QUOTE_ALL already escapes embedded quotes, so no
    # manual '"' doubling is needed (doing both corrupts the logged text).
    with open(log_file_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_ALL)
        writer.writerow([
            datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            user_id,
            str(input_type),
            str(query),
            str(diagnosis),
            str(confidence_score),
            str(status)
        ])
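# Note: session_data holds a single "latest" slot shared by every visitor
# to the Space, so concurrent sessions can overwrite each other's context
# before feedback is submitted; a per-session store (e.g. gr.State) would
# be needed for true multi-user isolation.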
def show_logs():
    try:
        df = pd.read_csv(log_file_path)
        return df.tail(100)
    except Exception as e:
        return f"⚠️ Error: {e}"
# def create_summary_pdf(text, filename_prefix="diagnosis_report"):
# try:
# filename = f"{filename_prefix}_{uuid.uuid4().hex[:6]}.pdf"
# filepath = os.path.join(".", filename) # Save in current directory
# pdf = FPDF()
# pdf.add_page()
# pdf.set_font("Arial", style='B', size=14)
# pdf.cell(200, 10, txt="🧠 Mental Health Diagnosis Report", ln=True, align='C')
# pdf.set_font("Arial", size=12)
# pdf.ln(10)
# wrapped = textwrap.wrap(text, width=90)
# for line in wrapped:
# pdf.cell(200, 10, txt=line, ln=True)
# pdf.output(filepath)
# print(f"βœ… PDF created at: {filepath}")
# return filepath
# except Exception as e:
# print(f"❌ Error creating PDF: {e}")
# return None
def create_summary_txt(text, filename_prefix="diagnosis_report"):
    filename = f"{filename_prefix}_{uuid.uuid4().hex[:6]}.txt"
    with open(filename, "w", encoding="utf-8") as f:
        f.write(text)
    print(f"✅ TXT report created: {filename}")
    return filename
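# Reports accumulate in the working directory; writing them under
# tempfile.gettempdir() (or cleaning up old files) may be preferable on a
# long-running Space.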
# 📥 Feedback
# feedback_data = []
# def submit_feedback(feedback, input_type, query, diagnosis, confidence_score, status):
# feedback_id = str(uuid.uuid4())
# timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
# with open(feedback_file_path, "a", newline="", encoding="utf-8") as f:
# writer = csv.writer(f, quoting=csv.QUOTE_ALL)
# writer.writerow([
# feedback_id,
# timestamp,
# input_type.replace('"', '""'),
# query.replace('"', '""'),
# diagnosis.replace('"', '""'),
# str(confidence_score),
# status,
# feedback.replace('"', '""')
# ])
# return f"βœ… Feedback received! Your Feedback ID: {feedback_id}"
def submit_feedback(feedback):
    if "latest" not in session_data:
        return "⚠️ No diagnosis found for this session. Please get a diagnosis first."
    user_info = session_data["latest"]
    feedback_id = f"fb_{uuid.uuid4().hex[:8]}"
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(feedback_file_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_ALL)
        writer.writerow([
            feedback_id,
            timestamp,
            user_info["user_id"],
            user_info["input_type"],
            user_info["query"],
            user_info["diagnosis"],
            user_info["status"],
            feedback
        ])
    return f"✅ Feedback received! Your Feedback ID: {feedback_id}"
def download_feedback_log():
    return feedback_file_path
# def send_email_report(to_email, response):
# response = resend.Emails.send({
# "from": "MentalBot <[email protected]>",
# "to": [to_email],
# "subject": "🧠 Your Personalized Mental Health Report",
# "text": response
# })
# return "βœ… Diagnosis report sent to your email!" if response.get("id") else "⚠️ Failed to send email."
# For pdf
# def unified_handler(audio, text):
# if audio:
# response, download_path = process_whisper_query(audio)
# else:
# response, download_path = process_query(text, input_type="text")
# # Ensure download path is valid
# # if not (download_path and os.path.exists(download_path)):
# # print("❌ PDF not found or failed to generate.")
# # return response, None
# if download_path and os.path.exists(download_path):
# return response, download_path
# else:
# print("❌ PDF not found or failed to generate.")
# return response, None
# For text doc download
def unified_handler(audio, text):
    if audio:
        response, _ = process_whisper_query(audio)
    else:
        response, _ = process_query(text, input_type="text")
    # Re-create the report so the download includes the session header
    download_path = create_summary_txt(response)
    return response, download_path
# Gradio UI
main_assistant_tab = gr.Interface(
    fn=unified_handler,
    inputs=[
        gr.Audio(type="filepath", label="🎙 Speak your concern"),
        gr.Textbox(lines=2, placeholder="Or type your mental health concern here...")
    ],
    outputs=[
        gr.Textbox(label="🧠 Personalized Diagnosis", lines=15, show_copy_button=True),
        gr.File(label="📥 Download Diagnosis Report")
    ],
    title="🧠 SafeSpace AI",
    description="💙 *We care for you.*\n\nSpeak or type your concern to receive AI-powered mental health insights. Get your report emailed or download it as a file."
)
dashboard_tab = gr.Interface(
    fn=diagnosis_dashboard,
    inputs=[],
    outputs=gr.Plot(label="📊 Diagnosis Distribution"),
    title="📊 Usage Dashboard"
)

logs_tab = gr.Interface(
    fn=show_logs,
    inputs=[],
    outputs=gr.Dataframe(label="📄 Diagnosis Logs (Latest 100 entries)"),
    title="📄 Logs"
)

feedback_tab = gr.Interface(
    fn=submit_feedback,
    inputs=[gr.Textbox(label="📝 Share your thoughts")],
    outputs="text",
    title="📝 Submit Feedback"
)

feedback_download_tab = gr.Interface(
    fn=download_feedback_log,
    inputs=[],
    outputs=gr.File(label="📥 Download All Feedback Logs"),
    title="📂 Download Feedback CSV"
)

agent_tab = gr.Interface(
    fn=lambda: "",
    inputs=[],
    outputs=gr.HTML(
        """<button onclick="window.open('https://jaamie-mental-health-agent.hf.space', '_blank')"
        style='padding:10px 20px; font-size:16px; background-color:#4CAF50; color:white; border:none; border-radius:5px;'>
        🧠 Launch Agent SafeSpace 001
        </button>"""
    ),
    title="🤖 Agent SafeSpace 001"
)

# Combine all tabs into one app
app = gr.TabbedInterface(
    interface_list=[
        main_assistant_tab,
        dashboard_tab,
        logs_tab,
        feedback_tab,
        feedback_download_tab,
        agent_tab
    ],
    tab_names=[
        "🧠 Assistant",
        "📊 Dashboard",
        "📄 Logs",
        "📝 Feedback",
        "📂 Feedback CSV",
        "🤖 Agent 001"
    ]
)
# app.launch(share=True)
print("🚀 SafeSpace AI is live!")

# Launch the Gradio App
if __name__ == "__main__":
    app.launch()