# -*- coding: utf-8 -*-
"""app.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1BmTzCgYHoIX81jKTqf4ImJaKRRbxgoTS
"""
import os
import csv
import pandas as pd
import plotly.express as px
from datetime import datetime
import torch
import faiss
import numpy as np
import gradio as gr
from sklearn.metrics.pairwise import cosine_similarity
# from google.colab import drive
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
from peft import PeftModel  # used by the commented-out LoRA loading block below
from huggingface_hub import login
from fpdf import FPDF  # used by the commented-out PDF report helper below
import uuid
import textwrap  # used by the commented-out PDF report helper below
from dotenv import load_dotenv
try:
    import whisper
except ImportError:
    os.system("pip install -U openai-whisper")
    import whisper

# Load the Whisper speech-to-text model
whisper_model = whisper.load_model("base")
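# Note: "base" trades accuracy for speed; openai-whisper also ships "tiny",
# "small", "medium", and "large" checkpoints if transcription quality matters
# more than latency (larger checkpoints need correspondingly more memory).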
load_dotenv()
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)
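# Assumes HF_TOKEN is defined in a local .env file (or as a Space secret);
# login() is needed here because google/gemma-2-9b-it is a gated model.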
# Mount Google Drive
# drive.mount('/content/drive')

# -------------------------------
# 🔧 Configuration
# -------------------------------
base_model_path = "google/gemma-2-9b-it"
# peft_model_path = "Jaamie/gemma-mental-health-qlora"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
embedding_model_bge = "BAAI/bge-base-en-v1.5"
# save_path_bge = "./models/bge-base-en-v1.5"
faiss_index_path = "./qa_faiss_embedding.index"
chunked_text_path = "./chunked_text_RAG_text.txt"
READER_MODEL_NAME = "google/gemma-2-9b-it"
# READER_MODEL_NAME = "google/gemma-2b-it"
log_file_path = "./diagnosis_logs.csv"
feedback_file_path = "./feedback_logs.csv"
# -------------------------------
# 🔧 Logging setup
# -------------------------------
if not os.path.exists(log_file_path):
    with open(log_file_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow(["timestamp", "user_id", "input_type", "query", "diagnosis", "confidence_score", "status"])

# -------------------------------
# 🔧 Feedback setup
# -------------------------------
if not os.path.exists(feedback_file_path):
    with open(feedback_file_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        writer.writerow([
            "feedback_id", "timestamp", "user_id", "input_type", "query",
            "diagnosis", "status", "feedback"
        ])

# Ensure directory exists
# os.makedirs(save_path_bge, exist_ok=True)
# -------------------------------
# 🔧 Model setup
# -------------------------------
# Load Sentence Transformer Model
# if not os.path.exists(os.path.join(save_path_bge, "config.json")):
#     print("Saving model to Google Drive...")
#     embedding_model = SentenceTransformer(embedding_model_bge)
#     embedding_model.save(save_path_bge)
#     print("Model saved successfully!")
# else:
#     print("Loading model from Google Drive...")
#     device = 'cuda' if torch.cuda.is_available() else 'cpu'
#     embedding_model = SentenceTransformer(save_path_bge, device=device)
embedding_model = SentenceTransformer(embedding_model_bge, device=device)
print("✅ BGE Embedding model loaded from Hugging Face.")

# Load FAISS Index
faiss_index = faiss.read_index(faiss_index_path)
print("FAISS index loaded successfully!")
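# The index file is assumed to have been built offline over the same BGE
# embeddings, roughly like this (sketch under that assumption, not executed here):
#   chunk_embs = embedding_model.encode(chunks, normalize_embeddings=True)
#   index = faiss.IndexFlatIP(chunk_embs.shape[1])  # inner product == cosine on normalized vectors
#   index.add(np.asarray(chunk_embs, dtype=np.float32))
#   faiss.write_index(index, "qa_faiss_embedding.index")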
# Load chunked text
def load_chunked_text():
    with open(chunked_text_path, "r", encoding="utf-8") as f:
        return f.read().split("\n\n---\n\n")

chunked_text = load_chunked_text()
print(f"Loaded {len(chunked_text)} text chunks.")

# Load the emotion classifier
emotion_classifier = pipeline("text-classification", model="nateraw/bert-base-uncased-emotion")
# -------------------------------
# 🔧 Load base model + LoRA adapter
# -------------------------------
# base_model = AutoModelForCausalLM.from_pretrained(
#     base_model_path,
#     torch_dtype=torch.float16,
#     device_map="auto"  # Use accelerate for smart placement
# )
# # Load the LoRA adapter on top of the base model
# diagnosis_model = PeftModel.from_pretrained(
#     base_model,
#     peft_model_path
# ).to(device)
# # Load tokenizer from the same fine-tuned repo
# diagnosis_tokenizer = AutoTokenizer.from_pretrained(peft_model_path)
# # Set model to evaluation mode
# diagnosis_model.eval()
# print("✅ Model & tokenizer loaded successfully.")
# # Create text-generation pipeline WITHOUT `device` arg
# READER_LLM = pipeline(
#     model=diagnosis_model,
#     tokenizer=diagnosis_tokenizer,
#     task="text-generation",
#     do_sample=True,
#     temperature=0.2,
#     repetition_penalty=1.1,
#     return_full_text=False,
#     max_new_tokens=500
# )
tokenizer = AutoTokenizer.from_pretrained(READER_MODEL_NAME)
# model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(READER_MODEL_NAME).to(device)

# model_id = "mistralai/Mistral-7B-Instruct-v0.1"
# # model_id = "TheBloke/Gemma-2-7B-IT-GGUF"
# tokenizer = AutoTokenizer.from_pretrained(model_id)
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     torch_dtype=torch.float16,
#     device_map="auto",
# ).to(device)

READER_LLM = pipeline(
    model=model,
    tokenizer=tokenizer,
    task="text-generation",
    do_sample=True,
    temperature=0.2,
    repetition_penalty=1.1,
    return_full_text=False,
    max_new_tokens=500,
    # device=device,
)
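# Sampling settings: a low temperature (0.2) keeps the generated diagnosis close
# to the retrieved context, repetition_penalty=1.1 discourages the model from
# looping, and return_full_text=False strips the prompt from the returned output.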
# -------------------------------
# 🔧 Whisper Model Setup
# -------------------------------
def process_whisper_query(audio):
    try:
        audio_data = whisper.load_audio(audio)
        audio_data = whisper.pad_or_trim(audio_data)
        mel = whisper.log_mel_spectrogram(audio_data).to(whisper_model.device)
        result = whisper_model.decode(mel, whisper.DecodingOptions(fp16=False))
        transcribed_text = result.text.strip()
        response, download_path = process_query(transcribed_text, input_type="voice")
        return response, download_path
    except Exception as e:
        return f"⚠️ Error processing audio: {str(e)}", None
def extract_diagnosis(response_text: str) -> str:
    for line in response_text.splitlines():
        if "Diagnosed Mental Disorder" in line:
            return line.split(":")[-1].strip()
    return "Unknown"
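# Note: this parser relies on the model following the "**Diagnosed Mental
# Disorder**: <label>" line format requested in the prompt below; any response
# that deviates from it is reported as "Unknown" and logged as a fallback.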
# Estimate answer correctness (guard against hallucination)
def calculate_rag_confidence(query_embedding, top_k_docs_embeddings, generation_logprobs=None):
    """
    Combines retriever and generation signals to compute a confidence score.

    Args:
        query_embedding (np.ndarray): Embedding vector of the user query (shape: [1, dim]).
        top_k_docs_embeddings (np.ndarray): Embedding matrix of top-k retrieved documents (shape: [k, dim]).
        generation_logprobs (list, optional): List of logprobs for generated tokens.

    Returns:
        float: Final confidence score (0 to 1).
    """
    retriever_similarities = cosine_similarity(query_embedding, top_k_docs_embeddings)
    retriever_confidence = float(np.max(retriever_similarities))
    if generation_logprobs:
        gen_confidence = float(np.exp(np.mean(generation_logprobs)))
    else:
        gen_confidence = 0.0  # fallback if unavailable
    alpha, beta = 0.6, 0.4
    final_confidence = alpha * retriever_confidence + beta * gen_confidence
    return round(final_confidence, 4)
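# Worked example (sketch): if the best retrieved-document cosine similarity is
# 0.87 and no token logprobs are supplied, the score is
# 0.6 * 0.87 + 0.4 * 0.0 = 0.522. As called from process_query below, no
# logprobs are ever passed, so in practice the score is capped at alpha (0.6).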
# Main Process
def process_query(user_query, input_type="text"):
    # Embed the query
    query_embedding = embedding_model.encode(user_query, normalize_embeddings=True)
    query_embedding = np.array([query_embedding], dtype=np.float32)

    # Search FAISS index
    k = 5  # Retrieve top 5 relevant docs
    distances, indices = faiss_index.search(query_embedding, k)
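    # With normalize_embeddings=True above, and assuming an inner-product FAISS
    # index, `distances` are effectively cosine similarities (higher = closer).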
    retrieved_docs = [chunked_text[i] for i in indices[0] if i != -1]  # FAISS pads missing hits with -1

    # Construct context
    context = "\nExtracted documents:\n" + "".join([f"Document {i}:::\n{doc}\n" for i, doc in enumerate(retrieved_docs)])

    # Detect emotion
    emotion_result = emotion_classifier(user_query)[0]
    print(f"Detected emotion: {emotion_result}")
    emotion = emotion_result['label']
    value = round(emotion_result['score'], 2)
    # Define RAG prompt
    prompt_in_chat_format = [
        {"role": "user", "content": f"""
You are an AI assistant specialized in diagnosing mental disorders in humans.
Using the information contained in the context, answer the question comprehensively.
The **Diagnosed Mental Disorder** should be only one from the list provided:
[Normal, Depression, Suicidal, Anxiety, Stress, Bi-Polar, Personality Disorder]
Your response must include:
1. **Diagnosed Mental Disorder**
2. **Detected emotion** {emotion}
3. **Intensity of emotion** {value}
4. **Matching Symptoms** from the context
5. **Personalized Treatment**
6. **Helpline Numbers**
7. **Source Link** (if applicable)
Make sure to provide a comprehensive and accurate diagnosis and explain the personalised treatment in detail.
If a disorder cannot be determined, return **Diagnosed Mental Disorder** as "Unknown".
---
Context:
{context}

Question: {user_query}"""},
    ]
    # add_generation_prompt=True already appends the model-turn header, so no
    # trailing empty assistant message is needed in the conversation above.
    RAG_PROMPT_TEMPLATE = tokenizer.apply_chat_template(
        prompt_in_chat_format, tokenize=False, add_generation_prompt=True
    )
    # Generate response
    # answer = READER_LLM(RAG_PROMPT_TEMPLATE)[0]["generated_text"]
    try:
        response = READER_LLM(RAG_PROMPT_TEMPLATE)
        # print("🔍 Raw LLM output:", response)
        answer = response[0]["generated_text"] if response and "generated_text" in response[0] else "⚠️ No output generated."
    except Exception as e:
        print("❌ Error during generation:", e)
        answer = "⚠️ An error occurred while generating the response."

    # Get embeddings of retrieved docs
    retrieved_doc_embeddings = embedding_model.encode(retrieved_docs, normalize_embeddings=True)
    retrieved_doc_embeddings = np.array(retrieved_doc_embeddings, dtype=np.float32)

    # Calculate RAG-based confidence
    confidence_score = calculate_rag_confidence(query_embedding, retrieved_doc_embeddings)

    # Add to response
    answer += f"\n\n🧠 Accuracy & Closeness of the Answer: {confidence_score:.2f}"
    answer += "\n\n*Derived from semantic similarity and generation certainty."

    # Extract the diagnosis label
    diagnosis = extract_diagnosis(answer)
    status = "fallback" if diagnosis.lower() == "unknown" else "success"

    # Log interaction
    log_query(input_type=input_type, query=user_query, diagnosis=diagnosis, confidence_score=confidence_score, status=status)

    download_path = create_summary_txt(answer)
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    user_id = session_data["latest"]["user_id"]

    # Prepend the session header to the answer string
    answer_header = f"🧾 Session ID: {user_id}\n📅 Timestamp: {timestamp}\n\n"
    return answer_header + answer, download_path
# Dashboard Interface
def diagnosis_dashboard():
    try:
        df = pd.read_csv(log_file_path)
        if df.empty:
            return "No data logged yet."

        # Filter out unknown or fallback cases if needed
        df = df[df["diagnosis"].notna()]
        df = df[df["diagnosis"].str.lower() != "unknown"]

        # Diagnosis frequency
        diagnosis_counts = df["diagnosis"].value_counts().reset_index()
        diagnosis_counts.columns = ["Diagnosis", "Count"]

        # Create bar chart
        fig = px.bar(
            diagnosis_counts,
            x="Diagnosis",
            y="Count",
            color="Diagnosis",
            title="📊 Mental Health Diagnosis Distribution",
            text_auto=True
        )
        fig.update_layout(showlegend=False)
        return fig
    except Exception as e:
        return f"⚠️ Error loading dashboard: {str(e)}"
# For logs functionality
# def log_query(input_type, query, diagnosis, confidence_score, status):
#     with open(log_file_path, "a", newline="", encoding="utf-8") as f:
#         writer = csv.writer(f, quoting=csv.QUOTE_ALL)
#         writer.writerow([
#             datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
#             input_type.replace('"', '""'),
#             query.replace('"', '""'),
#             diagnosis.replace('"', '""'),
#             str(confidence_score),
#             status
#         ])
session_data = {}

def log_query(input_type, query, diagnosis, confidence_score, status):
    user_id = f"SSuser_ID_{uuid.uuid4().hex[:8]}"

    # Store in-memory session data for feedback use
    session_data["latest"] = {
        "user_id": user_id,
        "input_type": input_type,
        "query": query,
        "diagnosis": diagnosis,
        "confidence_score": confidence_score,
        "status": status
    }

    with open(log_file_path, "a", newline="", encoding="utf-8") as f:
        # csv.writer already escapes embedded quotes under QUOTE_ALL, so values
        # are written as-is (pre-doubling quotes here would corrupt the data).
        writer = csv.writer(f, quoting=csv.QUOTE_ALL)
        writer.writerow([
            datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            str(user_id),
            str(input_type),
            str(query),
            str(diagnosis),
            str(confidence_score),
            str(status)
        ])
def show_logs():
    try:
        df = pd.read_csv(log_file_path)
        return df.tail(100)
    except Exception as e:
        return f"⚠️ Error: {e}"
# def create_summary_pdf(text, filename_prefix="diagnosis_report"):
#     try:
#         filename = f"{filename_prefix}_{uuid.uuid4().hex[:6]}.pdf"
#         filepath = os.path.join(".", filename)  # Save in current directory
#         pdf = FPDF()
#         pdf.add_page()
#         pdf.set_font("Arial", style='B', size=14)
#         pdf.cell(200, 10, txt="🧠 Mental Health Diagnosis Report", ln=True, align='C')
#         pdf.set_font("Arial", size=12)
#         pdf.ln(10)
#         wrapped = textwrap.wrap(text, width=90)
#         for line in wrapped:
#             pdf.cell(200, 10, txt=line, ln=True)
#         pdf.output(filepath)
#         print(f"✅ PDF created at: {filepath}")
#         return filepath
#     except Exception as e:
#         print(f"❌ Error creating PDF: {e}")
#         return None
def create_summary_txt(text, filename_prefix="diagnosis_report"):
    filename = f"{filename_prefix}_{uuid.uuid4().hex[:6]}.txt"
    with open(filename, "w", encoding="utf-8") as f:
        f.write(text)
    print(f"✅ TXT report created: {filename}")
    return filename
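# Note: each request writes a new uniquely named report into the working
# directory and nothing deletes them; in a long-running Space you may want to
# write into tempfile.gettempdir() instead and prune old files periodically.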
# 📥 Feedback
# feedback_data = []
# def submit_feedback(feedback, input_type, query, diagnosis, confidence_score, status):
#     feedback_id = str(uuid.uuid4())
#     timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
#     with open(feedback_file_path, "a", newline="", encoding="utf-8") as f:
#         writer = csv.writer(f, quoting=csv.QUOTE_ALL)
#         writer.writerow([
#             feedback_id,
#             timestamp,
#             input_type.replace('"', '""'),
#             query.replace('"', '""'),
#             diagnosis.replace('"', '""'),
#             str(confidence_score),
#             status,
#             feedback.replace('"', '""')
#         ])
#     return f"✅ Feedback received! Your Feedback ID: {feedback_id}"
def submit_feedback(feedback):
    # Guard against feedback being submitted before any diagnosis has run
    if "latest" not in session_data:
        return "⚠️ No diagnosis found for this session. Please get a diagnosis first."
    user_info = session_data["latest"]
    feedback_id = f"fb_{uuid.uuid4().hex[:8]}"
    timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
    with open(feedback_file_path, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, quoting=csv.QUOTE_ALL)
        writer.writerow([
            feedback_id,
            timestamp,
            user_info["user_id"],
            user_info["input_type"],
            user_info["query"],
            user_info["diagnosis"],
            user_info["status"],
            feedback  # csv.writer handles quote escaping under QUOTE_ALL
        ])
    return f"✅ Feedback received! Your Feedback ID: {feedback_id}"
def download_feedback_log():
    return feedback_file_path

# def send_email_report(to_email, response):
#     response = resend.Emails.send({
#         "from": "MentalBot <[email protected]>",
#         "to": [to_email],
#         "subject": "🧠 Your Personalized Mental Health Report",
#         "text": response
#     })
#     return "✅ Diagnosis report sent to your email!" if response.get("id") else "⚠️ Failed to send email."
# For PDF download
# def unified_handler(audio, text):
#     if audio:
#         response, download_path = process_whisper_query(audio)
#     else:
#         response, download_path = process_query(text, input_type="text")
#     # Ensure download path is valid
#     if download_path and os.path.exists(download_path):
#         return response, download_path
#     else:
#         print("❌ PDF not found or failed to generate.")
#         return response, None
# For text doc download
def unified_handler(audio, text):
    if audio:
        response, _ = process_whisper_query(audio)
    else:
        response, _ = process_query(text, input_type="text")
    download_path = create_summary_txt(response)
    return response, download_path
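# The TXT path returned by process_query is discarded and the report is
# rewritten here so that the downloadable file also contains the session ID /
# timestamp header that process_query prepends to the response text.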
# Gradio UI
main_assistant_tab = gr.Interface(
    fn=unified_handler,
    inputs=[
        gr.Audio(type="filepath", label="🎤 Speak your concern"),
        gr.Textbox(lines=2, placeholder="Or type your mental health concern here...")
    ],
    outputs=[
        gr.Textbox(label="🧠 Personalized Diagnosis", lines=15, show_copy_button=True),
        gr.File(label="📥 Download Diagnosis Report")
    ],
    title="🧠 SafeSpace AI",
    description="💚 *We care for you.*\n\nSpeak or type your concern to receive AI-powered mental health insights, and download your report as a file."
)
dashboard_tab = gr.Interface(
    fn=diagnosis_dashboard,
    inputs=[],
    outputs=gr.Plot(label="📊 Diagnosis Distribution"),
    title="📊 Usage Dashboard"
)

logs_tab = gr.Interface(
    fn=show_logs,
    inputs=[],
    outputs=gr.Dataframe(label="📋 Diagnosis Logs (Latest 100 entries)"),
    title="📋 Logs"
)

feedback_tab = gr.Interface(
    fn=submit_feedback,
    inputs=[gr.Textbox(label="📝 Share your thoughts")],
    outputs="text",
    title="📝 Submit Feedback"
)

feedback_download_tab = gr.Interface(
    fn=download_feedback_log,
    inputs=[],
    outputs=gr.File(label="📥 Download All Feedback Logs"),
    title="📄 Download Feedback CSV"
)

agent_tab = gr.Interface(
    fn=lambda: "",
    inputs=[],
    outputs=gr.HTML(
        """<button onclick="window.open('https://jaamie-mental-health-agent.hf.space', '_blank')"
        style='padding:10px 20px; font-size:16px; background-color:#4CAF50; color:white; border:none; border-radius:5px;'>
        🧠 Launch Agent SafeSpace 001
        </button>"""
    ),
    title="🤖 Agent SafeSpace 001"
)
# Assemble the tabbed app
app = gr.TabbedInterface(
    interface_list=[
        main_assistant_tab,
        dashboard_tab,
        logs_tab,
        feedback_tab,
        feedback_download_tab,
        agent_tab
    ],
    tab_names=[
        "🧠 Assistant",
        "📊 Dashboard",
        "📋 Logs",
        "📝 Feedback",
        "📄 Feedback CSV",
        "🤖 Agent 001"
    ]
)
# app.launch(share=True)
print("🚀 SafeSpace AI is live!")

# Launch the Gradio App
if __name__ == "__main__":
app.launch() | |