import gradio as gr import sqlite3 import bcrypt from datetime import datetime import re import torch import torch.nn as nn import torch.nn.functional as F from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification import os import logging from openai import OpenAI import json from fpdf import FPDF print("ENV:", os.environ) # 👈 Add this for debugging api_key = os.getenv("OPENAI_API_KEY") if not api_key: raise RuntimeError("OPENAI_API_KEY environment variable not found.") client = OpenAI(api_key=api_key) import json from fpdf import FPDF # -------------------------- # Environment Setup # -------------------------- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print("Using device:", device) # -------------------------- # Global Tokenizer and Hybrid Model for Treatment Prediction # -------------------------- tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT") class HybridMentalHealthModel(nn.Module): def __init__(self, bert_model, num_genders, num_medications, num_therapies, hidden_size=128): super(HybridMentalHealthModel, self).__init__() self.bert = AutoModel.from_pretrained(bert_model) bert_output_size = self.bert.config.hidden_size self.age_fc = nn.Linear(1, 16) self.gender_fc = nn.Embedding(num_genders, 16) self.fc = nn.Linear(bert_output_size + 32, hidden_size) self.medication_head = nn.Linear(hidden_size, num_medications) self.therapy_head = nn.Linear(hidden_size, num_therapies) def forward(self, input_ids, attention_mask, age, gender): bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :] age_out = self.age_fc(age) gender_out = self.gender_fc(gender) combined = torch.cat((bert_output, age_out, gender_out), dim=1) hidden = torch.relu(self.fc(combined)) return self.medication_head(hidden), self.therapy_head(hidden) # -------------------------- # Global Label Mappings and Age Scaler # -------------------------- medication_classes = ["Anxiolytics", "Benzodiazepines", "Antidepressants", "Mood Stabilizers", "Antipsychotics", "Stimulants"] therapy_classes = ["Cognitive Behavioral Therapy", "Dialectical Behavioral Therapy", "Interpersonal Therapy", "Mindfulness-Based Therapy"] # Update with your types gender_classes = ["Male", "Female", "Other"] medication_encoder = {name: idx for idx, name in enumerate(medication_classes)} inv_medication_encoder = {idx: name for name, idx in medication_encoder.items()} therapy_encoder = {name: idx for idx, name in enumerate(therapy_classes)} inv_therapy_encoder = {idx: name for name, idx in therapy_encoder.items()} gender_encoder = {name: idx for idx, name in enumerate(gender_classes)} mean_age = 50 std_age = 10 def scale_age(age): return (age - mean_age) / std_age # -------------------------- # Load the Hybrid Model (Treatment Prediction) # -------------------------- num_genders = len(gender_classes) num_medications = len(medication_classes) num_therapies = len(therapy_classes) MODEL_SAVE_PATH = "22.03.2025-16.02-ML128E10" # Update accordingly model = HybridMentalHealthModel("emilyalsentzer/Bio_ClinicalBERT", num_genders, num_medications, num_therapies) state_dict = torch.load(MODEL_SAVE_PATH, map_location=device) if "gender_fc.weight" in state_dict: del state_dict["gender_fc.weight"] model.load_state_dict(state_dict, strict=False) model.to(device) model.eval() # -------------------------- # Global Diagnosis Model (Mental Health Diagnosis) # -------------------------- diagnosis_tokenizer = AutoTokenizer.from_pretrained("ethandavey/mental-health-diagnosis-bert") # Update with your model ID diagnosis_model = AutoModelForSequenceClassification.from_pretrained("ethandavey/mental-health-diagnosis-bert") # Update with your model ID diagnosis_model.to(device) diagnosis_model.eval() def predict_disease(text): inputs = diagnosis_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128) inputs = {k: v.to(device) for k, v in inputs.items()} with torch.no_grad(): outputs = diagnosis_model(**inputs) probabilities = F.softmax(outputs.logits, dim=1).squeeze() label_mapping = {0: "Anxiety", 1: "Normal", 2: "Depression", 3: "Suicidal", 4: "Stress"} topk = torch.topk(probabilities, k=3) top_preds = [(label_mapping[i.item()], probabilities[i].item()) for i in topk.indices] return top_preds def predict_med_therapy(symptoms, age, gender): encoding = tokenizer(symptoms, return_tensors="pt", truncation=True, padding='max_length', max_length=128) input_ids = encoding["input_ids"].to(device) attention_mask = encoding["attention_mask"].to(device) age_norm = torch.tensor([[scale_age(age)]], dtype=torch.float32).to(device) gender_idx = gender_encoder.get(gender, 0) gender_tensor = torch.tensor([gender_idx], dtype=torch.long).to(device) with torch.no_grad(): med_logits, therapy_logits = model(input_ids, attention_mask, age_norm, gender_tensor) med_probabilities = torch.softmax(med_logits, dim=1) therapy_probabilities = torch.softmax(therapy_logits, dim=1) med_pred = torch.argmax(med_probabilities, dim=1).item() therapy_pred = torch.argmax(therapy_probabilities, dim=1).item() med_confidence = med_probabilities[0][med_pred].item() therapy_confidence = therapy_probabilities[0][therapy_pred].item() predicted_med = inv_medication_encoder.get(med_pred, "Unknown") predicted_therapy = inv_therapy_encoder.get(therapy_pred, "Unknown") return (predicted_med, med_confidence), (predicted_therapy, therapy_confidence) # -------------------------- # OpenAI Functions (Summarization and Explanation) # -------------------------- def get_concise_rewrite(text, max_tokens, temperature=0.7): messages = [ {"role": "system", "content": "You are an expert rewriting assistant. Rewrite the given statement into a concise version while preserving its tone and vocabulary."}, {"role": "user", "content": text} ] try: response = client.chat.completions.create(model="gpt-4o-mini", messages=messages, max_tokens=max_tokens, temperature=temperature) concise_text = response.choices[0].message.content.strip() except Exception as e: concise_text = f"API call failed: {e}" return concise_text def get_explanation(patient_statement, predicted_diagnosis): messages = [ {"role": "system", "content": "You are an expert mental health assistant. Provide a concise, evidence-based explanation of how the patient's statement supports the diagnosis."}, {"role": "user", "content": f"Patient statement: {patient_statement}\nPredicted diagnosis: {predicted_diagnosis}\nExplain briefly."} ] try: response = client.chat.completions.create(model="gpt-4o-mini", messages=messages, max_tokens=256) explanation = response.choices[0].message.content.strip() except Exception as e: explanation = "API call failed." return explanation # -------------------------- # Database Functions # -------------------------- def init_db(): conn = sqlite3.connect("users.db") c = conn.cursor() c.execute(""" CREATE TABLE IF NOT EXISTS users ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT UNIQUE NOT NULL, password TEXT NOT NULL, full_name TEXT, email TEXT ) """) c.execute(""" CREATE TABLE IF NOT EXISTS chat_history ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT NOT NULL, message TEXT NOT NULL, response TEXT NOT NULL, timestamp DATETIME DEFAULT CURRENT_TIMESTAMP ) """) c.execute(""" CREATE TABLE IF NOT EXISTS patient_sessions ( id INTEGER PRIMARY KEY AUTOINCREMENT, username TEXT, patient_name TEXT, age REAL, gender TEXT, symptoms TEXT, diagnosis TEXT, medication TEXT, therapy TEXT, summary TEXT, explanation TEXT, pdf_report TEXT, session_timestamp DATETIME DEFAULT CURRENT_TIMESTAMP, appointment_date DATE ) """) conn.commit() conn.close() def register_user(username, password, full_name, email): if not re.fullmatch(r"[^@]+@[^@]+\.[^@]+", email): return "Invalid email format." if len(password) <= 8: return "Password must be more than 8 characters." conn = sqlite3.connect("users.db") c = conn.cursor() hashed = bcrypt.hashpw(password.encode(), bcrypt.gensalt()) try: c.execute("INSERT INTO users (username, password, full_name, email) VALUES (?, ?, ?, ?)", (username, hashed, full_name, email)) conn.commit() return "User registered successfully." except sqlite3.IntegrityError: return "Username already exists." finally: conn.close() def login_user(username, password): conn = sqlite3.connect("users.db") c = conn.cursor() c.execute("SELECT password FROM users WHERE username = ?", (username,)) user = c.fetchone() conn.close() if user and bcrypt.checkpw(password.encode(), user[0]): return True return False def get_user_info(username): conn = sqlite3.connect("users.db") c = conn.cursor() c.execute("SELECT username, email, full_name FROM users WHERE username = ?", (username,)) user = c.fetchone() conn.close() if user: return f"Username: {user[0]}\nFull Name: {user[2]}\nEmail: {user[1]}" else: return "User info not found." def get_chat_history(username): conn = sqlite3.connect("users.db") c = conn.cursor() c.execute("SELECT message, response, timestamp FROM chat_history WHERE username = ? ORDER BY timestamp DESC", (username,)) history = c.fetchall() conn.close() return history def get_patient_sessions(filter_name="", filter_date=""): conn = sqlite3.connect("users.db") c = conn.cursor() query = "SELECT patient_name, age, gender, symptoms, diagnosis, medication, therapy, summary, explanation, pdf_report, session_timestamp FROM patient_sessions WHERE 1=1" params = [] if filter_name: query += " AND patient_name LIKE ?" params.append(f"%{filter_name}%") if filter_date: query += " AND DATE(session_timestamp)=?" params.append(filter_date) c.execute(query, params) sessions = c.fetchall() conn.close() return sessions def insert_patient_session(session_data): conn = sqlite3.connect("users.db") c = conn.cursor() c.execute(""" INSERT INTO patient_sessions (username, patient_name, age, gender, symptoms, diagnosis, medication, therapy, summary, explanation, pdf_report, appointment_date) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """, ( session_data.get("username"), session_data.get("patient_name"), session_data.get("age"), session_data.get("gender"), session_data.get("symptoms"), session_data.get("diagnosis"), session_data.get("medication"), session_data.get("therapy"), session_data.get("summary"), session_data.get("explanation"), session_data.get("pdf_report"), session_data.get("appointment_date"))) conn.commit() conn.close() # -------------------------- # PDF Report Generation Function # -------------------------- def generate_pdf_report(session_data): pdf = FPDF() pdf.add_page() pdf.set_font("Arial", size=12) # Use safe_text to ensure the title is safe for latin-1 encoding pdf.cell(200, 10, txt=safe_text("Patient Session Report"), ln=True, align='C') pdf.ln(10) for key, value in session_data.items(): # Convert each line to a safe text version before writing it pdf.multi_cell(0, 10, txt=safe_text(f"{key.capitalize()}: {value}")) reports_dir = "reports" os.makedirs(reports_dir, exist_ok=True) timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") filename = f"{reports_dir}/{session_data.get('patient_name')}_{timestamp}.pdf" pdf.output(filename) return filename def safe_text(txt): # Encode the text to latin-1, replacing characters that can't be encoded return txt.encode("latin-1", "replace").decode("latin-1") # -------------------------- # Helper: Autofill Previous Patient Info # -------------------------- def get_previous_patients(): # Use the current logged-in user from user_session.value username = user_session.value conn = sqlite3.connect("users.db") c = conn.cursor() c.execute("SELECT DISTINCT patient_name FROM patient_sessions WHERE username=?", (username,)) records = c.fetchall() conn.close() return [r[0] for r in records] def get_previous_patient_info(selected_patient): # Use the current logged-in user from user_session.value username = user_session.value conn = sqlite3.connect("users.db") c = conn.cursor() c.execute( "SELECT patient_name, age, gender FROM patient_sessions WHERE username=? AND patient_name=? ORDER BY session_timestamp DESC LIMIT 1", (username, selected_patient) ) record = c.fetchone() conn.close() if record: return record[0], record[1], record[2] else: return "", None, "" # -------------------------- # Gradio UI Setup with External CSS # -------------------------- with gr.Blocks(css=open("styles.css", "r").read(), theme="soft") as app: user_session = gr.State(value="") profile_visible = gr.State(value=False) session_data_state = gr.State(value="") with gr.Row(elem_id="header") as header_row: with gr.Column(scale=8): gr.Markdown("## Mental Health Chatbot") with gr.Column(visible=False, elem_id="profile_container") as profile_container: profile_button = gr.Button("👤", elem_id="profile_button", variant="secondary") with gr.Column(visible=False, elem_id="profile_info_box") as profile_info_box: profile_info = gr.HTML() logout_button = gr.Button("Logout", elem_id="logout_button") with gr.Column(visible=True, elem_id="login_page") as login_page: gr.Markdown("## Login") with gr.Row(): username_login = gr.Textbox(label="Username") password_login = gr.Textbox(label="Password", type="password") login_btn = gr.Button("Login") login_output = gr.Textbox(label="Login Status", interactive=False) gr.Markdown("New user? Click below to register.") go_to_register = gr.Button("Go to Register") with gr.Column(visible=False, elem_id="register_page") as register_page: gr.Markdown("## Register") new_username = gr.Textbox(label="New Username") new_password = gr.Textbox(label="New Password", type="password") full_name = gr.Textbox(label="Full Name") email = gr.Textbox(label="Email") register_btn = gr.Button("Register") register_output = gr.Textbox(label="Registration Status", interactive=False) gr.Markdown("Already have an account?") back_to_login = gr.Button("Back to Login") with gr.Tabs(visible=False, elem_id="main_panel") as main_panel: with gr.Tab("Chatbot"): with gr.Row(): with gr.Column(scale=1): previous_patient = gr.Dropdown(label="Previous Patients", choices=[], interactive=True) patient_name_input = gr.Textbox(placeholder="Enter patient name", label="Patient Name") gender_input = gr.Dropdown(choices=list(gender_encoder.keys()), label="Gender") age_input = gr.Number(label="Age") symptoms_input = gr.Textbox(placeholder="Describe symptoms", label="Symptoms", lines=4) submit = gr.Button("Submit") generate_report_btn = gr.Button("Generate Report", visible=False) with gr.Column(scale=1): with gr.Row(): with gr.Column(scale=4, min_width=240): # Textbox column diagnosis_textbox = gr.Textbox(label="Diagnosis", interactive=False) with gr.Column(scale=1, min_width=120): # Confidence column diagnosis_conf_html = gr.HTML(elem_classes=["confidence-container"]) with gr.Row(): with gr.Column(scale=4, min_width=240): medication_textbox = gr.Textbox(label="Medication", interactive=False) with gr.Column(scale=1, min_width=120): medication_conf_html = gr.HTML(elem_classes=["confidence-container"]) with gr.Row(): with gr.Column(scale=4, min_width=240): therapy_textbox = gr.Textbox(label="Therapy", interactive=False) with gr.Column(scale=1, min_width=120): therapy_conf_html = gr.HTML(elem_classes=["confidence-container"]) summary_textbox = gr.Textbox(label="Concise Summary", interactive=False) explanation_textbox = gr.Textbox(label="Explanation", interactive=False) with gr.Row(): report_download = gr.File(label="Download Report", interactive=False) def handle_chat_extended(patient_name, gender, age, symptoms): if age is None or age <= 0: error_msg = "Age must be greater than 0." return (error_msg, "", error_msg, "", error_msg, "", error_msg, error_msg, gr.update(visible=False)) if age > 150: error_msg2 = "Age must be lower than 150" return (error_msg2, "", error_msg2, "", error_msg2, "", error_msg2, error_msg2, gr.update(visible=False)) if len(symptoms.split()) > 512: msg = "Input exceeds maximum allowed length of 512 words." return (msg, "", msg, "", msg, "", msg, msg, gr.update(visible=False)) full_statement = f"Patient Name: {patient_name}, Gender: {gender}, Age: {age}, Symptoms: {symptoms}" summary = get_concise_rewrite(full_statement, max_tokens=150, temperature=0.7) # Predict top 3 diagnoses diagnosis_preds = predict_disease(full_statement) # Now returns list of (label, confidence) diagnosis_display = "\n".join([f"{label}" for label, _ in diagnosis_preds]) def get_confidence_class(percentage): if percentage <= 50: return "confidence-low" elif percentage <= 74: return "confidence-medium" else: return "confidence-high" diagnosis_conf_html_val = "