import json
import logging
import os
import re
import sqlite3
from datetime import datetime

import bcrypt
import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
from fpdf import FPDF
from openai import OpenAI
from transformers import AutoModel, AutoModelForSequenceClassification, AutoTokenizer

# --------------------------
# Environment Setup
# --------------------------
# NOTE: the environment is deliberately not printed. Dumping os.environ writes
# OPENAI_API_KEY (and any other secrets) to stdout / server logs.
api_key = os.getenv("OPENAI_API_KEY")
if not api_key:
    raise RuntimeError("OPENAI_API_KEY environment variable not found.")
client = OpenAI(api_key=api_key)

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# --------------------------
# Global Tokenizer and Hybrid Model for Treatment Prediction
# --------------------------
tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")


class HybridMentalHealthModel(nn.Module):
    """Predict medication and therapy classes from text plus age and gender.

    The first-position hidden state of a pretrained BERT encoder is concatenated
    with a 16-dim projection of the age feature and a 16-dim gender embedding,
    passed through one ReLU hidden layer, and fed to two classification heads.

    Args:
        bert_model: Hugging Face model id or path passed to ``AutoModel.from_pretrained``.
        num_genders: Size of the gender embedding table.
        num_medications: Number of medication output classes.
        num_therapies: Number of therapy output classes.
        hidden_size: Width of the shared hidden layer.
    """

    def __init__(self, bert_model, num_genders, num_medications, num_therapies, hidden_size=128):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model)
        bert_output_size = self.bert.config.hidden_size
        self.age_fc = nn.Linear(1, 16)
        self.gender_fc = nn.Embedding(num_genders, 16)
        # 32 = 16 (age projection) + 16 (gender embedding).
        self.fc = nn.Linear(bert_output_size + 32, hidden_size)
        self.medication_head = nn.Linear(hidden_size, num_medications)
        self.therapy_head = nn.Linear(hidden_size, num_therapies)

    def forward(self, input_ids, attention_mask, age, gender):
        """Return ``(medication_logits, therapy_logits)`` (unnormalised scores).

        ``age`` must be a float tensor of shape (batch, 1) (it feeds ``nn.Linear(1, 16)``).
        ``gender`` must be a long tensor of shape (batch,) holding embedding indices
        (its embedding is concatenated along dim 1).
        """
        # Hidden state of the first token position of every sequence.
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
        age_out = self.age_fc(age)
        gender_out = self.gender_fc(gender)
        combined = torch.cat((bert_output, age_out, gender_out), dim=1)
        hidden = torch.relu(self.fc(combined))
        return self.medication_head(hidden), self.therapy_head(hidden)
import html
from contextlib import closing

logger = logging.getLogger(__name__)

# --------------------------
# Global Label Mappings and Age Scaler
# --------------------------
medication_classes = [
    "Anxiolytics",
    "Benzodiazepines",
    "Antidepressants",
    "Mood Stabilizers",
    "Antipsychotics",
    "Stimulants",
]
therapy_classes = [
    "Cognitive Behavioral Therapy",
    "Dialectical Behavioral Therapy",
    "Interpersonal Therapy",
    "Mindfulness-Based Therapy",
]  # Update with your types
gender_classes = ["Male", "Female", "Other"]

# Name <-> index lookups; list order defines how model output indices map to names.
medication_encoder = {name: idx for idx, name in enumerate(medication_classes)}
inv_medication_encoder = {idx: name for name, idx in medication_encoder.items()}
therapy_encoder = {name: idx for idx, name in enumerate(therapy_classes)}
inv_therapy_encoder = {idx: name for name, idx in therapy_encoder.items()}
gender_encoder = {name: idx for idx, name in enumerate(gender_classes)}

# Standardisation constants for the age feature.
# NOTE(review): presumably these must match the statistics used at training time — confirm.
mean_age = 50
std_age = 10


def scale_age(age):
    """Return *age* standardised with the module-level ``mean_age`` / ``std_age``."""
    return (age - mean_age) / std_age


# --------------------------
# Load the Hybrid Model (Treatment Prediction)
# --------------------------
num_genders = len(gender_classes)
num_medications = len(medication_classes)
num_therapies = len(therapy_classes)
MODEL_SAVE_PATH = "22.03.2025-16.02-ML128E10"  # Update accordingly

model = HybridMentalHealthModel("emilyalsentzer/Bio_ClinicalBERT", num_genders, num_medications, num_therapies)
state_dict = torch.load(MODEL_SAVE_PATH, map_location=device)
# The checkpoint's gender embedding is discarded, so the freshly initialised
# embedding is used instead. NOTE(review): presumably the checkpoint was trained
# with a different number of gender classes; the gender input is effectively
# untrained — confirm this is intended.
if "gender_fc.weight" in state_dict:
    del state_dict["gender_fc.weight"]
load_result = model.load_state_dict(state_dict, strict=False)
# strict=False silently ignores mismatches; surface them so a wrong checkpoint is noticed.
if load_result.missing_keys or load_result.unexpected_keys:
    logger.warning(
        "Checkpoint mismatch - missing keys: %s; unexpected keys: %s",
        load_result.missing_keys,
        load_result.unexpected_keys,
    )
model.to(device)
model.eval()

# --------------------------
# Global Diagnosis Model (Mental Health Diagnosis)
# --------------------------
DIAGNOSIS_MODEL_ID = "ethandavey/mental-health-diagnosis-bert"  # Update with your model ID
diagnosis_tokenizer = AutoTokenizer.from_pretrained(DIAGNOSIS_MODEL_ID)
diagnosis_model = AutoModelForSequenceClassification.from_pretrained(DIAGNOSIS_MODEL_ID)
diagnosis_model.to(device)
diagnosis_model.eval()

# Output index -> diagnosis label for the diagnosis classifier.
DIAGNOSIS_LABELS = {0: "Anxiety", 1: "Normal", 2: "Depression", 3: "Suicidal", 4: "Stress"}


def predict_disease(text, k=3):
    """Return the *k* most likely diagnoses for *text*.

    Returns:
        A list of ``(label, probability)`` tuples, most likely first. *k* is capped
        at the number of classes the model outputs.
    """
    inputs = diagnosis_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
    inputs = {key: value.to(device) for key, value in inputs.items()}
    with torch.no_grad():
        outputs = diagnosis_model(**inputs)
    # Single input text -> drop the batch dimension.
    probabilities = F.softmax(outputs.logits, dim=1).squeeze(0)
    topk = torch.topk(probabilities, k=min(k, probabilities.numel()))
    return [
        (DIAGNOSIS_LABELS.get(idx.item(), f"Label {idx.item()}"), prob.item())
        for prob, idx in zip(topk.values, topk.indices)
    ]


def predict_med_therapy(symptoms, age, gender):
    """Predict one medication class and one therapy class for a patient.

    Args:
        symptoms: Free-text symptom description.
        age: Age in years (standardised with ``scale_age`` before use).
        gender: One of ``gender_classes``; anything else falls back to index 0.

    Returns:
        ``((medication_name, confidence), (therapy_name, confidence))`` where each
        confidence is the softmax probability of the chosen class.
    """
    encoding = tokenizer(symptoms, return_tensors="pt", truncation=True, padding='max_length', max_length=128)
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)
    age_norm = torch.tensor([[scale_age(age)]], dtype=torch.float32).to(device)
    gender_idx = gender_encoder.get(gender, 0)
    gender_tensor = torch.tensor([gender_idx], dtype=torch.long).to(device)

    with torch.no_grad():
        med_logits, therapy_logits = model(input_ids, attention_mask, age_norm, gender_tensor)

    med_probabilities = torch.softmax(med_logits, dim=1)
    therapy_probabilities = torch.softmax(therapy_logits, dim=1)
    med_pred = torch.argmax(med_probabilities, dim=1).item()
    therapy_pred = torch.argmax(therapy_probabilities, dim=1).item()
    med_confidence = med_probabilities[0][med_pred].item()
    therapy_confidence = therapy_probabilities[0][therapy_pred].item()

    predicted_med = inv_medication_encoder.get(med_pred, "Unknown")
    predicted_therapy = inv_therapy_encoder.get(therapy_pred, "Unknown")
    return (predicted_med, med_confidence), (predicted_therapy, therapy_confidence)


# --------------------------
# OpenAI Functions (Summarization and Explanation)
# --------------------------
OPENAI_CHAT_MODEL = "gpt-4o-mini"


def get_concise_rewrite(text, max_tokens, temperature=0.7):
    """Ask the OpenAI chat model for a concise rewrite of *text*.

    Returns the rewritten text, or ``"API call failed: <error>"`` if the request fails.
    """
    messages = [
        {"role": "system", "content": "You are an expert rewriting assistant. Rewrite the given statement into a concise version while preserving its tone and vocabulary."},
        {"role": "user", "content": text},
    ]
    try:
        response = client.chat.completions.create(
            model=OPENAI_CHAT_MODEL, messages=messages, max_tokens=max_tokens, temperature=temperature
        )
        return response.choices[0].message.content.strip()
    except Exception as e:  # best-effort: the UI shows the failure instead of crashing
        logger.exception("Concise rewrite request failed")
        return f"API call failed: {e}"


def get_explanation(patient_statement, predicted_diagnosis):
    """Ask the OpenAI chat model how *patient_statement* supports *predicted_diagnosis*.

    Returns the explanation text, or ``"API call failed."`` if the request fails.
    """
    messages = [
        {"role": "system", "content": "You are an expert mental health assistant. Provide a concise, evidence-based explanation of how the patient's statement supports the diagnosis."},
        {"role": "user", "content": f"Patient statement: {patient_statement}\nPredicted diagnosis: {predicted_diagnosis}\nExplain briefly."},
    ]
    try:
        response = client.chat.completions.create(model=OPENAI_CHAT_MODEL, messages=messages, max_tokens=256)
        return response.choices[0].message.content.strip()
    except Exception:  # best-effort, but log instead of silently swallowing
        logger.exception("Explanation request failed")
        return "API call failed."


# --------------------------
# Database Functions
# --------------------------
DB_PATH = "users.db"


def init_db():
    """Create the ``users``, ``chat_history`` and ``patient_sessions`` tables if missing."""
    # closing() guarantees the connection is closed; `with conn` commits (or rolls back).
    with closing(sqlite3.connect(DB_PATH)) as conn, conn:
        conn.execute("""
            CREATE TABLE IF NOT EXISTS users (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT UNIQUE NOT NULL,
                password TEXT NOT NULL,
                full_name TEXT,
                email TEXT
            )
        """)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS chat_history (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT NOT NULL,
                message TEXT NOT NULL,
                response TEXT NOT NULL,
                timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
            )
        """)
        conn.execute("""
            CREATE TABLE IF NOT EXISTS patient_sessions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                username TEXT,
                patient_name TEXT,
                age REAL,
                gender TEXT,
                symptoms TEXT,
                diagnosis TEXT,
                medication TEXT,
                therapy TEXT,
                summary TEXT,
                explanation TEXT,
                pdf_report TEXT,
                session_timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
                appointment_date DATE
            )
        """)


def register_user(username, password, full_name, email):
    """Create a user account with a bcrypt-hashed password.

    Returns a human-readable status message (success or the validation error).
    """
    if not username or not username.strip():
        return "Username is required."
    if not re.fullmatch(r"[^@]+@[^@]+\.[^@]+", email):
        return "Invalid email format."
    if len(password) <= 8:
        return "Password must be more than 8 characters."
    hashed = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
    try:
        with closing(sqlite3.connect(DB_PATH)) as conn, conn:
            conn.execute(
                "INSERT INTO users (username, password, full_name, email) VALUES (?, ?, ?, ?)",
                (username, hashed, full_name, email),
            )
    except sqlite3.IntegrityError:  # UNIQUE(username) violated
        return "Username already exists."
    return "User registered successfully."


def login_user(username, password):
    """Return True if *password* matches the stored bcrypt hash for *username*."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        row = conn.execute("SELECT password FROM users WHERE username = ?", (username,)).fetchone()
    if not row:
        return False
    stored = row[0]
    # register_user stores bytes; tolerate a TEXT hash in case a row was written elsewhere.
    if isinstance(stored, str):
        stored = stored.encode()
    return bcrypt.checkpw(password.encode(), stored)


def get_user_info(username):
    """Return a plain-text, newline-separated profile summary for *username*."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        user = conn.execute(
            "SELECT username, email, full_name FROM users WHERE username = ?", (username,)
        ).fetchone()
    if user:
        return f"Username: {user[0]}\nFull Name: {user[2]}\nEmail: {user[1]}"
    return "User info not found."


def get_chat_history(username):
    """Return ``(message, response, timestamp)`` rows for *username*, newest first."""
    with closing(sqlite3.connect(DB_PATH)) as conn:
        return conn.execute(
            "SELECT message, response, timestamp FROM chat_history WHERE username = ? ORDER BY timestamp DESC",
            (username,),
        ).fetchall()


def save_chat_message(username, message, response):
    """Append one exchange to ``chat_history`` for *username*."""
    with closing(sqlite3.connect(DB_PATH)) as conn, conn:
        conn.execute(
            "INSERT INTO chat_history (username, message, response) VALUES (?, ?, ?)",
            (username, message, response),
        )


def get_patient_sessions(filter_name="", filter_date="", username=None):
    """Search saved patient sessions, newest first.

    Args:
        filter_name: Optional substring matched against ``patient_name``.
        filter_date: Optional ``YYYY-MM-DD`` matched against the session date.
        username: When given, only sessions recorded by this user are returned.
    """
    query = (
        "SELECT patient_name, age, gender, symptoms, diagnosis, medication, therapy, summary, "
        "explanation, pdf_report, session_timestamp FROM patient_sessions WHERE 1=1"
    )
    params = []
    if username is not None:
        query += " AND username = ?"
        params.append(username)
    if filter_name:
        query += " AND patient_name LIKE ?"
        params.append(f"%{filter_name}%")
    if filter_date:
        query += " AND DATE(session_timestamp)=?"
        params.append(filter_date)
    query += " ORDER BY session_timestamp DESC"
    with closing(sqlite3.connect(DB_PATH)) as conn:
        return conn.execute(query, params).fetchall()


def insert_patient_session(session_data):
    """Persist one patient session; missing keys in *session_data* are stored as NULL."""
    with closing(sqlite3.connect(DB_PATH)) as conn, conn:
        conn.execute("""
            INSERT INTO patient_sessions
            (username, patient_name, age, gender, symptoms, diagnosis, medication, therapy,
             summary, explanation, pdf_report, appointment_date)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        """, (
            session_data.get("username"),
            session_data.get("patient_name"),
            session_data.get("age"),
            session_data.get("gender"),
            session_data.get("symptoms"),
            session_data.get("diagnosis"),
            session_data.get("medication"),
            session_data.get("therapy"),
            session_data.get("summary"),
            session_data.get("explanation"),
            session_data.get("pdf_report"),
            session_data.get("appointment_date"),
        ))


# --------------------------
# PDF Report Generation Function
# --------------------------
def _safe_filename_component(name):
    """Reduce *name* to letters, digits, '-' and '_' so it cannot escape the reports dir."""
    cleaned = re.sub(r"[^A-Za-z0-9_-]+", "_", str(name or "")).strip("_")
    return cleaned or "patient"


def generate_pdf_report(session_data):
    """Write *session_data* as a one-page "key: value" PDF under ``reports/``.

    Returns the path of the written file.
    """
    pdf = FPDF()
    pdf.add_page()
    pdf.set_font("Arial", size=12)
    # FPDF core fonts are latin-1 only; safe_text replaces unsupported characters.
    pdf.cell(200, 10, txt=safe_text("Patient Session Report"), ln=True, align='C')
    pdf.ln(10)
    for key, value in session_data.items():
        pdf.multi_cell(0, 10, txt=safe_text(f"{key.capitalize()}: {value}"))

    reports_dir = "reports"
    os.makedirs(reports_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    # patient_name is user input: sanitise it so names like "../x" or "a/b" cannot
    # write outside reports_dir or fail on a missing directory.
    safe_name = _safe_filename_component(session_data.get("patient_name"))
    filename = os.path.join(reports_dir, f"{safe_name}_{timestamp}.pdf")
    pdf.output(filename)
    return filename


def safe_text(txt):
    """Return *txt* with characters not representable in latin-1 replaced by '?'."""
    return str(txt).encode("latin-1", "replace").decode("latin-1")


# --------------------------
# Helper: Autofill Previous Patient Info
# --------------------------
def get_previous_patients(username=None):
    """Return the distinct patient names recorded by *username*.

    Callers should pass *username*; omitting it falls back to the legacy
    ``user_session.value`` lookup, which is not per-browser-session.
    """
    if username is None:
        username = user_session.value
    with closing(sqlite3.connect(DB_PATH)) as conn:
        records = conn.execute(
            "SELECT DISTINCT patient_name FROM patient_sessions WHERE username=?", (username,)
        ).fetchall()
    return [r[0] for r in records]


def get_previous_patient_info(selected_patient, username=None):
    """Return ``(name, age, gender)`` from *username*'s latest session for *selected_patient*.

    Returns ``("", None, "")`` when no such session exists. See
    ``get_previous_patients`` for the *username* fallback.
    """
    if username is None:
        username = user_session.value
    with closing(sqlite3.connect(DB_PATH)) as conn:
        record = conn.execute(
            "SELECT patient_name, age, gender FROM patient_sessions "
            "WHERE username=? AND patient_name=? ORDER BY session_timestamp DESC LIMIT 1",
            (username, selected_patient),
        ).fetchone()
    if record:
        return record[0], record[1], record[2]
    return "", None, ""


# --------------------------
# Confidence badges
# --------------------------
def get_confidence_class(percentage):
    """Map a 0-100 confidence percentage to its CSS class name."""
    if percentage <= 50:
        return "confidence-low"
    if percentage <= 74:
        return "confidence-medium"
    return "confidence-high"


def confidence_badge_html(percentage):
    """Return an HTML snippet reading '<x.x>% confidence', tagged with its confidence class."""
    # NOTE(review): the original markup for these badges was lost; the element and
    # class layout here are reconstructed — align with styles.css.
    return f"<div class='{get_confidence_class(percentage)}'>{percentage:.1f}% confidence</div>"


# --------------------------
# Gradio UI Setup with External CSS
# --------------------------
with open("styles.css", "r", encoding="utf-8") as css_file:
    app_css = css_file.read()

with gr.Blocks(css=app_css, theme="soft") as app:
    # Per-browser-session state. Handlers must receive these as event inputs and
    # update them via outputs. Assigning `<state>.value` inside a handler mutates the
    # component default shared by every visitor, which leaked one user's login and
    # session data into everyone else's requests.
    user_session = gr.State(value="")
    profile_visible = gr.State(value=False)
    session_data_state = gr.State(value="")

    with gr.Row(elem_id="header") as header_row:
        with gr.Column(scale=8):
            gr.Markdown("## Mental Health Chatbot")
        with gr.Column(visible=False, elem_id="profile_container") as profile_container:
            profile_button = gr.Button("👤", elem_id="profile_button", variant="secondary")
            with gr.Column(visible=False, elem_id="profile_info_box") as profile_info_box:
                profile_info = gr.HTML()
                logout_button = gr.Button("Logout", elem_id="logout_button")

    with gr.Column(visible=True, elem_id="login_page") as login_page:
        gr.Markdown("## Login")
        with gr.Row():
            username_login = gr.Textbox(label="Username")
            password_login = gr.Textbox(label="Password", type="password")
        login_btn = gr.Button("Login")
        login_output = gr.Textbox(label="Login Status", interactive=False)
        gr.Markdown("New user? Click below to register.")
        go_to_register = gr.Button("Go to Register")

    with gr.Column(visible=False, elem_id="register_page") as register_page:
        gr.Markdown("## Register")
        new_username = gr.Textbox(label="New Username")
        new_password = gr.Textbox(label="New Password", type="password")
        full_name = gr.Textbox(label="Full Name")
        email = gr.Textbox(label="Email")
        register_btn = gr.Button("Register")
        register_output = gr.Textbox(label="Registration Status", interactive=False)
        gr.Markdown("Already have an account?")
        back_to_login = gr.Button("Back to Login")

    with gr.Tabs(visible=False, elem_id="main_panel") as main_panel:
        with gr.Tab("Chatbot"):
            with gr.Row():
                with gr.Column(scale=1):
                    previous_patient = gr.Dropdown(label="Previous Patients", choices=[], interactive=True)
                    patient_name_input = gr.Textbox(placeholder="Enter patient name", label="Patient Name")
                    gender_input = gr.Dropdown(choices=list(gender_encoder.keys()), label="Gender")
                    age_input = gr.Number(label="Age")
                    symptoms_input = gr.Textbox(placeholder="Describe symptoms", label="Symptoms", lines=4)
                    submit = gr.Button("Submit")
                    generate_report_btn = gr.Button("Generate Report", visible=False)
                with gr.Column(scale=1):
                    with gr.Row():
                        with gr.Column(scale=4, min_width=240):  # Textbox column
                            diagnosis_textbox = gr.Textbox(label="Diagnosis", interactive=False)
                        with gr.Column(scale=1, min_width=120):  # Confidence column
                            diagnosis_conf_html = gr.HTML(elem_classes=["confidence-container"])
                    with gr.Row():
                        with gr.Column(scale=4, min_width=240):
                            medication_textbox = gr.Textbox(label="Medication", interactive=False)
                        with gr.Column(scale=1, min_width=120):
                            medication_conf_html = gr.HTML(elem_classes=["confidence-container"])
                    with gr.Row():
                        with gr.Column(scale=4, min_width=240):
                            therapy_textbox = gr.Textbox(label="Therapy", interactive=False)
                        with gr.Column(scale=1, min_width=120):
                            therapy_conf_html = gr.HTML(elem_classes=["confidence-container"])
                    summary_textbox = gr.Textbox(label="Concise Summary", interactive=False)
                    explanation_textbox = gr.Textbox(label="Explanation", interactive=False)
            with gr.Row():
                report_download = gr.File(label="Download Report", interactive=False)

            def handle_chat_extended(patient_name, gender, age, symptoms, username=""):
                """Run diagnosis and treatment prediction for the form and log the exchange.

                Returns values for: diagnosis, diagnosis confidence HTML, medication,
                medication confidence HTML, therapy, therapy confidence HTML, summary,
                explanation, report-button update, and the session JSON for the report.
                """
                def error_outputs(msg):
                    # Message in every text output; hide the report button and clear
                    # any stale session data so an old report cannot be generated.
                    return (msg, "", msg, "", msg, "", msg, msg, gr.update(visible=False), "")

                symptoms = symptoms or ""
                if age is None or age <= 0:
                    return error_outputs("Age must be greater than 0.")
                if age > 150:
                    return error_outputs("Age must be 150 or lower.")
                if not symptoms.strip():
                    return error_outputs("Please describe the symptoms.")
                if len(symptoms.split()) > 512:
                    return error_outputs("Input exceeds maximum allowed length of 512 words.")

                full_statement = f"Patient Name: {patient_name}, Gender: {gender}, Age: {age}, Symptoms: {symptoms}"
                summary = get_concise_rewrite(full_statement, max_tokens=150, temperature=0.7)

                # Top-3 diagnoses as (label, probability) pairs.
                diagnosis_preds = predict_disease(full_statement)
                diagnosis_display = "\n".join(label for label, _ in diagnosis_preds)
                diagnosis_conf_html_val = (
                    "<div>" + "".join(confidence_badge_html(conf * 100) for _, conf in diagnosis_preds) + "</div>"
                )

                (med_pred, med_conf), (therapy_pred, therapy_conf) = predict_med_therapy(symptoms, age, gender)
                med_percentage = med_conf * 100
                therapy_percentage = therapy_conf * 100

                top_diag_labels = ", ".join(label for label, _ in diagnosis_preds)
                explanation = get_explanation(full_statement, f"{top_diag_labels}, {med_pred} and {therapy_pred}")

                top_diag_with_conf = ", ".join(f"{label} ({conf * 100:.1f}%)" for label, conf in diagnosis_preds)
                session_data = {
                    "patient_name": patient_name,
                    "age": age,
                    "gender": gender,
                    "symptoms": symptoms,
                    "diagnosis": top_diag_with_conf,
                    "medication": f"{med_pred} ({med_percentage:.1f}% confidence)",
                    "therapy": f"{therapy_pred} ({therapy_percentage:.1f}% confidence)",
                    "summary": summary,
                    "explanation": explanation,
                    "session_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
                }

                if username:
                    save_chat_message(username, full_statement, top_diag_with_conf)

                return (
                    diagnosis_display,
                    diagnosis_conf_html_val,
                    med_pred,
                    confidence_badge_html(med_percentage),
                    therapy_pred,
                    confidence_badge_html(therapy_percentage),
                    summary,
                    explanation,
                    gr.update(visible=True),
                    json.dumps(session_data),
                )

            submit.click(
                handle_chat_extended,
                inputs=[patient_name_input, gender_input, age_input, symptoms_input, user_session],
                outputs=[diagnosis_textbox, diagnosis_conf_html, medication_textbox, medication_conf_html,
                         therapy_textbox, therapy_conf_html, summary_textbox, explanation_textbox,
                         generate_report_btn, session_data_state],
            )

            def handle_generate_report(session_json, username):
                """Build the PDF for this browser session's last result and persist it.

                On any failure, a text file describing the error is returned for
                download instead.
                """
                try:
                    data = json.loads(session_json)
                    pdf_file = generate_pdf_report(data)
                    data["username"] = username
                    data["appointment_date"] = ""
                    data["pdf_report"] = pdf_file
                    insert_patient_session(data)
                    return pdf_file
                except Exception as e:  # surface the failure to the user as a downloadable file
                    logger.exception("Report generation failed")
                    error_msg = f"Error generating PDF report: {str(e)}"
                    reports_dir = "reports"
                    os.makedirs(reports_dir, exist_ok=True)
                    error_filename = f"{reports_dir}/error_report_{datetime.now().strftime('%Y%m%d_%H%M%S')}.txt"
                    with open(error_filename, "w", encoding="utf-8") as f:
                        f.write(error_msg)
                    return error_filename

            generate_report_btn.click(
                handle_generate_report,
                inputs=[session_data_state, user_session],
                outputs=report_download,
            )

            def autofill_previous(selected_patient, username):
                """Fill name/age/gender from the user's latest session for the chosen patient."""
                return get_previous_patient_info(selected_patient, username)

            previous_patient.change(
                autofill_previous,
                inputs=[previous_patient, user_session],
                outputs=[patient_name_input, age_input, gender_input],
            )

        with gr.Tab("Chat History"):
            history_output = gr.Textbox(label="Chat History", interactive=False)
            load_history_btn = gr.Button("Load History")

            def load_history(username):
                """Render the logged-in user's chat history as text, newest first."""
                if not username:
                    return "Please log in to view history."
                history = get_chat_history(username)
                chat_history_text = "\n".join(
                    f"[{timestamp}] {message}\nBot: {response}" for message, response, timestamp in history
                )
                return f"Username: {username}\n\n{chat_history_text}"

            load_history_btn.click(load_history, inputs=[user_session], outputs=history_output)

        with gr.Tab("Book an Appointment"):
            with gr.Row():
                with gr.Column():
                    patient_name_appt = gr.Textbox(label="Patient Name", placeholder="Enter your name")
                    doctor_name = gr.Dropdown(choices=["Dr. Smith", "Dr. Johnson", "Dr. Lee"], label="Select Doctor")
                    appointment_date = gr.Textbox(label="Appointment Date", placeholder="YYYY-MM-DD")
                    appointment_time = gr.Textbox(label="Appointment Time", placeholder="HH:MM (24-hour format)")
                    reason = gr.TextArea(label="Reason for Visit", placeholder="Describe your symptoms or reason for the visit")
                    book_button = gr.Button("Book Appointment")
                with gr.Column():
                    booking_output = gr.Textbox(label="Booking Confirmation", interactive=False)

            def book_appointment(patient_name, doctor_name, appointment_date, appointment_time, reason, username=""):
                """Validate the booking form and return a confirmation or error message.

                The booking is not persisted; only the confirmation text is returned.
                """
                if not username:
                    return "Please log in to book an appointment."

                patient_name = (patient_name or "").strip()
                doctor_name = (doctor_name or "").strip()
                appointment_date = (appointment_date or "").strip()
                appointment_time = (appointment_time or "").strip()
                reason = (reason or "").strip()

                if not (patient_name and doctor_name and appointment_date and appointment_time and reason):
                    return "Please fill in all the fields."
                if not re.fullmatch(r"[A-Za-z ]+", patient_name):
                    return "Patient name should contain only letters and spaces."

                try:
                    appointment_date_obj = datetime.strptime(appointment_date, "%Y-%m-%d")
                except ValueError:
                    return "Appointment date must be in YYYY-MM-DD format."
                try:
                    appointment_time_obj = datetime.strptime(appointment_time, "%H:%M")
                except ValueError:
                    return "Appointment time must be in HH:MM (24-hour) format."

                appointment_datetime = datetime.combine(appointment_date_obj.date(), appointment_time_obj.time())
                if appointment_datetime <= datetime.now():
                    return "Appointment date/time has already passed. Please select a future date and time."

                return (f"Appointment booked for {patient_name} with {doctor_name} on {appointment_date} at {appointment_time}.\n\n"
                        f"Reason: {reason}")

            book_button.click(
                book_appointment,
                inputs=[patient_name_appt, doctor_name, appointment_date, appointment_time, reason, user_session],
                outputs=booking_output,
            )

        with gr.Tab("Patient Sessions"):
            gr.Markdown("### Search Patient Sessions")
            search_name = gr.Textbox(label="Patient Name (optional)")
            search_date = gr.Textbox(label="Date (YYYY-MM-DD, optional)")
            search_button = gr.Button("Search")
            sessions_output = gr.Textbox(label="Sessions", interactive=False)

            def search_sessions(name, date, username=""):
                """List the logged-in user's saved sessions matching the optional filters."""
                # Sessions hold patient health data and registration is open, so
                # require a login and scope results to the clinician who recorded them.
                if not username:
                    return "Please log in to view patient sessions."
                sessions = get_patient_sessions(filter_name=name, filter_date=date, username=username)
                if not sessions:
                    return "No sessions found."
                return "\n\n".join(
                    f"Patient: {s[0]}\nAge: {s[1]}\nGender: {s[2]}\nSymptoms: {s[3]}\nDiagnosis: {s[4]}\n"
                    f"Medication: {s[5]}\nTherapy: {s[6]}\nSummary: {s[7]}\nExplanation: {s[8]}\n"
                    f"Report: {s[9]}\nSession Time: {s[10]}"
                    for s in sessions
                )

            search_button.click(
                search_sessions,
                inputs=[search_name, search_date, user_session],
                outputs=sessions_output,
            )

    def handle_register(username, password, full_name, email):
        """Thin UI wrapper around ``register_user``."""
        return register_user(username, password, full_name, email)

    def go_to_register_page():
        """Hide the login page and show the registration page."""
        return gr.update(visible=False), gr.update(visible=True)

    def back_to_login_page():
        """Show the login page and hide the registration page."""
        return gr.update(visible=True), gr.update(visible=False)

    go_to_register.click(go_to_register_page, outputs=[login_page, register_page])
    register_btn.click(handle_register, inputs=[new_username, new_password, full_name, email], outputs=register_output)
    back_to_login.click(back_to_login_page, outputs=[login_page, register_page])

    def toggle_profile(current_visible, username):
        """Toggle the profile box; returns (box update, new visibility, profile HTML)."""
        if not username:
            return gr.update(visible=False), False, ""
        new_visible = not current_visible
        info = get_user_info(username) if new_visible else ""
        # gr.HTML renders markup: escape user-supplied fields and keep line breaks visible.
        return gr.update(visible=new_visible), new_visible, html.escape(info).replace("\n", "<br>")

    profile_button.click(
        toggle_profile,
        inputs=[profile_visible, user_session],
        outputs=[profile_info_box, profile_visible, profile_info],
    )

    def handle_login(username, password):
        """Authenticate, reset the form fields and store the user in this session's state."""
        if login_user(username, password):
            status = f"Welcome, {username}!"
            # main_panel, login_page, header_row
            panel_updates = (gr.update(visible=True), gr.update(visible=False), gr.update(visible=True))
            prev_choices = get_previous_patients(username)
            session_user = username
        else:
            status = "Invalid credentials."
            panel_updates = (gr.update(), gr.update(), gr.update())
            prev_choices = []
            session_user = ""
        cleared_fields = (
            "",    # patient_name_input
            None,  # age_input
            None,  # gender_input
            "",    # symptoms_input
            "",    # diagnosis_textbox
            "",    # diagnosis_conf_html
            "",    # medication_textbox
            "",    # medication_conf_html
            "",    # therapy_textbox
            "",    # therapy_conf_html
            "",    # summary_textbox
            "",    # explanation_textbox
            gr.update(visible=False),  # generate_report_btn
            None,  # report_download
            "",    # session_data_state
            "",    # search_name
            "",    # search_date
            "",    # booking_output
            "",    # patient_name_appt
            "",    # appointment_date
            "",    # appointment_time
            "",    # reason
        )
        return (
            status,                                      # login_output
            *panel_updates,                              # main_panel, login_page, header_row
            gr.update(choices=prev_choices, value=None), # previous_patient
            *cleared_fields,
            gr.update(visible=bool(session_user)),       # profile_container
            session_user,                                # user_session
        )

    login_btn.click(
        handle_login,
        inputs=[username_login, password_login],
        outputs=[
            login_output, main_panel, login_page, header_row, previous_patient,
            patient_name_input, age_input, gender_input, symptoms_input,
            diagnosis_textbox, diagnosis_conf_html, medication_textbox, medication_conf_html,
            therapy_textbox, therapy_conf_html, summary_textbox, explanation_textbox,
            generate_report_btn, report_download, session_data_state,
            search_name, search_date, booking_output,
            patient_name_appt, appointment_date, appointment_time, reason,
            profile_container, user_session,
        ],
    )

    def handle_logout():
        """Return the UI to the login page and clear this session's user and report data."""
        return (
            gr.update(visible=False),  # main_panel
            gr.update(visible=True),   # login_page
            gr.update(visible=False),  # header_row
            gr.update(visible=False),  # profile_info_box
            False,                     # profile_visible
            "",                        # profile_info
            "",                        # login_output
            "",                        # history_output
            "",                        # username_login
            "",                        # password_login
            "",                        # new_username
            "",                        # new_password
            "",                        # full_name
            "",                        # email
            gr.update(visible=False),  # profile_container
            "",                        # user_session
            "",                        # session_data_state
        )

    # profile_container was previously missing from this list, so the profile icon
    # stayed visible after logout.
    logout_button.click(
        handle_logout,
        outputs=[
            main_panel, login_page, header_row, profile_info_box, profile_visible, profile_info,
            login_output, history_output, username_login, password_login,
            new_username, new_password, full_name, email,
            profile_container, user_session, session_data_state,
        ],
    )


def main():
    """Create the database tables if needed and start the Gradio server."""
    init_db()
    app.launch()


if __name__ == "__main__":
    main()