Meshari21 commited on
Commit
c06fe3c
·
verified ·
1 Parent(s): c287feb

create app.py

Browse files
Files changed (1) hide show
  1. app.py +646 -0
app.py ADDED
@@ -0,0 +1,646 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import sqlite3
3
+ import bcrypt
4
+ from datetime import datetime
5
+ import re
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
10
+ import os
11
+ import logging
12
+ from dotenv import load_dotenv
13
+ from openai import OpenAI
14
+ load_dotenv() # Loads .env file
15
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
16
+ import json
17
+ from fpdf import FPDF
18
+
19
+ # --------------------------
20
+ # Environment Setup
21
+ # --------------------------
22
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
23
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
+ print("Using device:", device)
25
+
26
+ # --------------------------
27
+ # Global Tokenizer and Hybrid Model for Treatment Prediction
28
+ # --------------------------
29
+ tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
30
+
31
+
32
+ class HybridMentalHealthModel(nn.Module):
33
+ def __init__(self, bert_model, num_genders, num_medications, num_therapies, hidden_size=128):
34
+ super(HybridMentalHealthModel, self).__init__()
35
+ self.bert = AutoModel.from_pretrained(bert_model)
36
+ bert_output_size = self.bert.config.hidden_size
37
+ self.age_fc = nn.Linear(1, 16)
38
+ self.gender_fc = nn.Embedding(num_genders, 16)
39
+ self.fc = nn.Linear(bert_output_size + 32, hidden_size)
40
+ self.medication_head = nn.Linear(hidden_size, num_medications)
41
+ self.therapy_head = nn.Linear(hidden_size, num_therapies)
42
+
43
+ def forward(self, input_ids, attention_mask, age, gender):
44
+ bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state[:, 0, :]
45
+ age_out = self.age_fc(age)
46
+ gender_out = self.gender_fc(gender)
47
+ combined = torch.cat((bert_output, age_out, gender_out), dim=1)
48
+ hidden = torch.relu(self.fc(combined))
49
+ return self.medication_head(hidden), self.therapy_head(hidden)
50
+
51
+
52
+ # --------------------------
53
+ # Global Label Mappings and Age Scaler
54
+ # --------------------------
55
+ medication_classes = ["Anxiolytics", "Benzodiazepines", "Antidepressants", "Mood Stabilizers", "Antipsychotics", "Stimulants"]
56
+ therapy_classes = ["Cognitive Behavioral Therapy", "Dialectical Behavioral Therapy", "Interpersonal Therapy", "Mindfulness-Based Therapy"] # Update with your types
57
+ gender_classes = ["Male", "Female", "Other"]
58
+
59
+ medication_encoder = {name: idx for idx, name in enumerate(medication_classes)}
60
+ inv_medication_encoder = {idx: name for name, idx in medication_encoder.items()}
61
+ therapy_encoder = {name: idx for idx, name in enumerate(therapy_classes)}
62
+ inv_therapy_encoder = {idx: name for name, idx in therapy_encoder.items()}
63
+ gender_encoder = {name: idx for idx, name in enumerate(gender_classes)}
64
+
65
+ mean_age = 50
66
+ std_age = 10
67
+
68
+ def scale_age(age):
69
+ return (age - mean_age) / std_age
70
+
71
+ # --------------------------
72
+ # Load the Hybrid Model (Treatment Prediction)
73
+ # --------------------------
74
+ num_genders = len(gender_classes)
75
+ num_medications = len(medication_classes)
76
+ num_therapies = len(therapy_classes)
77
+ MODEL_SAVE_PATH = "22.03.2025-16.02-ML128E10" # Update accordingly
78
+
79
+ model = HybridMentalHealthModel("emilyalsentzer/Bio_ClinicalBERT", num_genders, num_medications, num_therapies)
80
+ state_dict = torch.load(MODEL_SAVE_PATH, map_location=device)
81
+ if "gender_fc.weight" in state_dict:
82
+ del state_dict["gender_fc.weight"]
83
+ model.load_state_dict(state_dict, strict=False)
84
+ model.to(device)
85
+ model.eval()
86
+
87
+ # --------------------------
88
+ # Global Diagnosis Model (Mental Health Diagnosis)
89
+ # --------------------------
90
+ diagnosis_tokenizer = AutoTokenizer.from_pretrained("ethandavey/mental-health-diagnosis-bert") # Update with your model ID
91
+ diagnosis_model = AutoModelForSequenceClassification.from_pretrained("ethandavey/mental-health-diagnosis-bert") # Update with your model ID
92
+ diagnosis_model.to(device)
93
+ diagnosis_model.eval()
94
+
95
+ def predict_disease(text):
96
+ inputs = diagnosis_tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=128)
97
+ inputs = {k: v.to(device) for k, v in inputs.items()}
98
+ with torch.no_grad():
99
+ outputs = diagnosis_model(**inputs)
100
+ probabilities = F.softmax(outputs.logits, dim=1).squeeze()
101
+ label_mapping = {0: "Anxiety", 1: "Normal", 2: "Depression", 3: "Suicidal", 4: "Stress"}
102
+
103
+ topk = torch.topk(probabilities, k=3)
104
+ top_preds = [(label_mapping[i.item()], probabilities[i].item()) for i in topk.indices]
105
+ return top_preds
106
+
107
+
108
+ def predict_med_therapy(symptoms, age, gender):
109
+ encoding = tokenizer(symptoms, return_tensors="pt", truncation=True, padding='max_length', max_length=128)
110
+ input_ids = encoding["input_ids"].to(device)
111
+ attention_mask = encoding["attention_mask"].to(device)
112
+ age_norm = torch.tensor([[scale_age(age)]], dtype=torch.float32).to(device)
113
+ gender_idx = gender_encoder.get(gender, 0)
114
+ gender_tensor = torch.tensor([gender_idx], dtype=torch.long).to(device)
115
+ with torch.no_grad():
116
+ med_logits, therapy_logits = model(input_ids, attention_mask, age_norm, gender_tensor)
117
+ med_probabilities = torch.softmax(med_logits, dim=1)
118
+ therapy_probabilities = torch.softmax(therapy_logits, dim=1)
119
+ med_pred = torch.argmax(med_probabilities, dim=1).item()
120
+ therapy_pred = torch.argmax(therapy_probabilities, dim=1).item()
121
+ med_confidence = med_probabilities[0][med_pred].item()
122
+ therapy_confidence = therapy_probabilities[0][therapy_pred].item()
123
+ predicted_med = inv_medication_encoder.get(med_pred, "Unknown")
124
+ predicted_therapy = inv_therapy_encoder.get(therapy_pred, "Unknown")
125
+ return (predicted_med, med_confidence), (predicted_therapy, therapy_confidence)
126
+
127
+ # --------------------------
128
+ # OpenAI Functions (Summarization and Explanation)
129
+ # --------------------------
130
+ def get_concise_rewrite(text, max_tokens, temperature=0.7):
131
+ messages = [
132
+ {"role": "system", "content": "You are an expert rewriting assistant. Rewrite the given statement into a concise version while preserving its tone and vocabulary."},
133
+ {"role": "user", "content": text}
134
+ ]
135
+ try:
136
+ response = client.chat.completions.create(model="gpt-4o-mini", messages=messages, max_tokens=max_tokens, temperature=temperature)
137
+ concise_text = response.choices[0].message.content.strip()
138
+ except Exception as e:
139
+ concise_text = f"API call failed: {e}"
140
+ return concise_text
141
+
142
+ def get_explanation(patient_statement, predicted_diagnosis):
143
+ messages = [
144
+ {"role": "system", "content": "You are an expert mental health assistant. Provide a concise, evidence-based explanation of how the patient's statement supports the diagnosis."},
145
+ {"role": "user", "content": f"Patient statement: {patient_statement}\nPredicted diagnosis: {predicted_diagnosis}\nExplain briefly."}
146
+ ]
147
+ try:
148
+ response = client.chat.completions.create(model="gpt-4o-mini", messages=messages, max_tokens=256)
149
+ explanation = response.choices[0].message.content.strip()
150
+ except Exception as e:
151
+ explanation = "API call failed."
152
+ return explanation
153
+
154
+ # --------------------------
155
+ # Database Functions
156
+ # --------------------------
157
+ def init_db():
158
+ conn = sqlite3.connect("users.db")
159
+ c = conn.cursor()
160
+ c.execute("""
161
+ CREATE TABLE IF NOT EXISTS users (
162
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
163
+ username TEXT UNIQUE NOT NULL,
164
+ password TEXT NOT NULL,
165
+ full_name TEXT,
166
+ email TEXT
167
+ )
168
+ """)
169
+ c.execute("""
170
+ CREATE TABLE IF NOT EXISTS chat_history (
171
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
172
+ username TEXT NOT NULL,
173
+ message TEXT NOT NULL,
174
+ response TEXT NOT NULL,
175
+ timestamp DATETIME DEFAULT CURRENT_TIMESTAMP
176
+ )
177
+ """)
178
+ c.execute("""
179
+ CREATE TABLE IF NOT EXISTS patient_sessions (
180
+ id INTEGER PRIMARY KEY AUTOINCREMENT,
181
+ username TEXT,
182
+ patient_name TEXT,
183
+ age REAL,
184
+ gender TEXT,
185
+ symptoms TEXT,
186
+ diagnosis TEXT,
187
+ medication TEXT,
188
+ therapy TEXT,
189
+ summary TEXT,
190
+ explanation TEXT,
191
+ pdf_report TEXT,
192
+ session_timestamp DATETIME DEFAULT CURRENT_TIMESTAMP,
193
+ appointment_date DATE
194
+ )
195
+ """)
196
+ conn.commit()
197
+ conn.close()
198
+
199
+ def register_user(username, password, full_name, email):
200
+ if not re.fullmatch(r"[^@]+@[^@]+\.[^@]+", email):
201
+ return "Invalid email format."
202
+ if len(password) <= 8:
203
+ return "Password must be more than 8 characters."
204
+ conn = sqlite3.connect("users.db")
205
+ c = conn.cursor()
206
+ hashed = bcrypt.hashpw(password.encode(), bcrypt.gensalt())
207
+ try:
208
+ c.execute("INSERT INTO users (username, password, full_name, email) VALUES (?, ?, ?, ?)", (username, hashed, full_name, email))
209
+ conn.commit()
210
+ return "User registered successfully."
211
+ except sqlite3.IntegrityError:
212
+ return "Username already exists."
213
+ finally:
214
+ conn.close()
215
+
216
+ def login_user(username, password):
217
+ conn = sqlite3.connect("users.db")
218
+ c = conn.cursor()
219
+ c.execute("SELECT password FROM users WHERE username = ?", (username,))
220
+ user = c.fetchone()
221
+ conn.close()
222
+ if user and bcrypt.checkpw(password.encode(), user[0]):
223
+ return True
224
+ return False
225
+
226
+ def get_user_info(username):
227
+ conn = sqlite3.connect("users.db")
228
+ c = conn.cursor()
229
+ c.execute("SELECT username, email, full_name FROM users WHERE username = ?", (username,))
230
+ user = c.fetchone()
231
+ conn.close()
232
+ if user:
233
+ return f"Username: {user[0]}\nFull Name: {user[2]}\nEmail: {user[1]}"
234
+ else:
235
+ return "User info not found."
236
+
237
+ def get_chat_history(username):
238
+ conn = sqlite3.connect("users.db")
239
+ c = conn.cursor()
240
+ c.execute("SELECT message, response, timestamp FROM chat_history WHERE username = ? ORDER BY timestamp DESC", (username,))
241
+ history = c.fetchall()
242
+ conn.close()
243
+ return history
244
+
245
+ def get_patient_sessions(filter_name="", filter_date=""):
246
+ conn = sqlite3.connect("users.db")
247
+ c = conn.cursor()
248
+ query = "SELECT patient_name, age, gender, symptoms, diagnosis, medication, therapy, summary, explanation, pdf_report, session_timestamp FROM patient_sessions WHERE 1=1"
249
+ params = []
250
+ if filter_name:
251
+ query += " AND patient_name LIKE ?"
252
+ params.append(f"%{filter_name}%")
253
+ if filter_date:
254
+ query += " AND DATE(session_timestamp)=?"
255
+ params.append(filter_date)
256
+ c.execute(query, params)
257
+ sessions = c.fetchall()
258
+ conn.close()
259
+ return sessions
260
+
261
+ def insert_patient_session(session_data):
262
+ conn = sqlite3.connect("users.db")
263
+ c = conn.cursor()
264
+ c.execute("""
265
+ INSERT INTO patient_sessions (username, patient_name, age, gender, symptoms, diagnosis, medication, therapy, summary, explanation, pdf_report, appointment_date)
266
+ VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
267
+ """, (
268
+ session_data.get("username"), session_data.get("patient_name"), session_data.get("age"), session_data.get("gender"),
269
+ session_data.get("symptoms"), session_data.get("diagnosis"), session_data.get("medication"),
270
+ session_data.get("therapy"), session_data.get("summary"), session_data.get("explanation"),
271
+ session_data.get("pdf_report"), session_data.get("appointment_date")))
272
+ conn.commit()
273
+ conn.close()
274
+
275
+ # --------------------------
276
+ # PDF Report Generation Function
277
+ # --------------------------
278
+ def generate_pdf_report(session_data):
279
+ pdf = FPDF()
280
+ pdf.add_page()
281
+ pdf.set_font("Arial", size=12)
282
+ pdf.cell(200, 10, txt="Patient Session Report", ln=True, align='C')
283
+ pdf.ln(10)
284
+ for key, value in session_data.items():
285
+ pdf.multi_cell(0, 10, txt=f"{key.capitalize()}: {value}")
286
+ reports_dir = "reports"
287
+ os.makedirs(reports_dir, exist_ok=True)
288
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
289
+ filename = f"{reports_dir}/{session_data.get('patient_name')}_{timestamp}.pdf"
290
+ pdf.output(filename)
291
+ return filename
292
+
293
+ # --------------------------
294
+ # Helper: Autofill Previous Patient Info
295
+ # --------------------------
296
+ def get_previous_patient_info(selected_patient):
297
+ conn = sqlite3.connect("users.db")
298
+ c = conn.cursor()
299
+ c.execute("SELECT patient_name, age, gender FROM patient_sessions WHERE patient_name=? ORDER BY session_timestamp DESC LIMIT 1", (selected_patient,))
300
+ record = c.fetchone()
301
+ conn.close()
302
+ if record:
303
+ return record[0], record[1], record[2]
304
+ else:
305
+ return "", None, ""
306
+
307
+ def get_previous_patients():
308
+ conn = sqlite3.connect("users.db")
309
+ c = conn.cursor()
310
+ c.execute("SELECT DISTINCT patient_name FROM patient_sessions")
311
+ records = c.fetchall()
312
+ conn.close()
313
+ return [r[0] for r in records]
314
+
315
+ # --------------------------
316
+ # Gradio UI Setup with External CSS
317
+ # --------------------------
318
+ with gr.Blocks(css=open("styles.css", "r").read(), theme="soft") as app:
319
+ user_session = gr.State(value="")
320
+ profile_visible = gr.State(value=False)
321
+ session_data_state = gr.State(value="")
322
+
323
+ with gr.Row(elem_id="header") as header_row:
324
+ with gr.Column(scale=8):
325
+ gr.Markdown("## Mental Health Chatbot")
326
+ with gr.Column(scale=4) as profile_container:
327
+ profile_button = gr.Button("👤", elem_id="profile_button", variant="secondary")
328
+ with gr.Column(visible=False, elem_id="profile_info_box") as profile_info_box:
329
+ profile_info = gr.HTML()
330
+ logout_button = gr.Button("Logout", elem_id="logout_button")
331
+
332
+ with gr.Column(visible=True, elem_id="login_page") as login_page:
333
+ gr.Markdown("## Login")
334
+ with gr.Row():
335
+ username_login = gr.Textbox(label="Username")
336
+ password_login = gr.Textbox(label="Password", type="password")
337
+ login_btn = gr.Button("Login")
338
+ login_output = gr.Textbox(label="Login Status", interactive=False)
339
+ gr.Markdown("New user? Click below to register.")
340
+ go_to_register = gr.Button("Go to Register")
341
+
342
+ with gr.Column(visible=False, elem_id="register_page") as register_page:
343
+ gr.Markdown("## Register")
344
+ new_username = gr.Textbox(label="New Username")
345
+ new_password = gr.Textbox(label="New Password", type="password")
346
+ full_name = gr.Textbox(label="Full Name")
347
+ email = gr.Textbox(label="Email")
348
+ register_btn = gr.Button("Register")
349
+ register_output = gr.Textbox(label="Registration Status", interactive=False)
350
+ gr.Markdown("Already have an account?")
351
+ back_to_login = gr.Button("Back to Login")
352
+
353
+ with gr.Tabs(visible=False, elem_id="main_panel") as main_panel:
354
+ with gr.Tab("Chatbot"):
355
+ with gr.Row():
356
+ with gr.Column(scale=1):
357
+ previous_patient = gr.Dropdown(label="Previous Patients", choices=[], interactive=True)
358
+ patient_name_input = gr.Textbox(placeholder="Enter patient name", label="Patient Name")
359
+ gender_input = gr.Dropdown(choices=list(gender_encoder.keys()), label="Gender")
360
+ age_input = gr.Number(label="Age")
361
+ symptoms_input = gr.Textbox(placeholder="Describe symptoms", label="Symptoms", lines=4)
362
+ submit = gr.Button("Submit")
363
+ generate_report_btn = gr.Button("Generate Report", visible=False)
364
+ with gr.Column(scale=1):
365
+ with gr.Row():
366
+ with gr.Column(scale=4, min_width=240): # Textbox column
367
+ diagnosis_textbox = gr.Textbox(label="Diagnosis",
368
+ interactive=False)
369
+ with gr.Column(scale=1, min_width=120): # Confidence column
370
+ diagnosis_conf_html = gr.HTML(elem_classes=["confidence-container"])
371
+
372
+ with gr.Row():
373
+ with gr.Column(scale=4, min_width=240):
374
+ medication_textbox = gr.Textbox(label="Medication",
375
+ interactive=False)
376
+ with gr.Column(scale=1, min_width=120):
377
+ medication_conf_html = gr.HTML(elem_classes=["confidence-container"])
378
+
379
+ with gr.Row():
380
+ with gr.Column(scale=4, min_width=240):
381
+ therapy_textbox = gr.Textbox(label="Therapy",
382
+ interactive=False)
383
+ with gr.Column(scale=1, min_width=120):
384
+ therapy_conf_html = gr.HTML(elem_classes=["confidence-container"])
385
+ summary_textbox = gr.Textbox(label="Concise Summary", interactive=False)
386
+ explanation_textbox = gr.Textbox(label="Explanation", interactive=False)
387
+ with gr.Row():
388
+ report_download = gr.File(label="Download Report", interactive=False)
389
+
390
+ def handle_chat_extended(patient_name, gender, age, symptoms):
391
+ if age is None or age <= 0:
392
+ error_msg = "Age must be greater than 0."
393
+ return (error_msg, "", error_msg, "", error_msg, "", error_msg, error_msg, gr.update(visible=False))
394
+
395
+ if age > 150:
396
+ error_msg2 = "Age must be lower than 150"
397
+ return (error_msg2, "", error_msg2, "", error_msg2, "", error_msg2, error_msg2, gr.update(visible=False))
398
+
399
+ if len(symptoms.split()) > 512:
400
+ msg = "Input exceeds maximum allowed length of 512 words."
401
+ return (msg, "", msg, "", msg, "", msg, msg, gr.update(visible=False))
402
+
403
+ full_statement = f"Patient Name: {patient_name}, Gender: {gender}, Age: {age}, Symptoms: {symptoms}"
404
+ summary = get_concise_rewrite(full_statement, max_tokens=150, temperature=0.7)
405
+
406
+ # Predict top 3 diagnoses
407
+ diagnosis_preds = predict_disease(full_statement) # Now returns list of (label, confidence)
408
+ diagnosis_display = "\n".join([f"{label}" for label, _ in diagnosis_preds])
409
+
410
+ def get_confidence_class(percentage):
411
+ if percentage <= 50:
412
+ return "confidence-low"
413
+ elif percentage <= 74:
414
+ return "confidence-medium"
415
+ else:
416
+ return "confidence-high"
417
+
418
+ diagnosis_conf_html_val = "<div class='confidence-multi'>" + "<br>".join([
419
+ f"<div class='confidence-display'><span class='confidence-value {get_confidence_class(conf * 100)}'>{conf * 100:.1f}% confidence</span></div>"
420
+ for _, conf in diagnosis_preds
421
+ ]) + "</div>"
422
+
423
+ # Predict medication and therapy
424
+ (med_pred, med_conf), (therapy_pred, therapy_conf) = predict_med_therapy(symptoms, age, gender)
425
+ med_percentage = med_conf * 100
426
+ therapy_percentage = therapy_conf * 100
427
+
428
+ def get_conf_html(percentage):
429
+ return f"""
430
+ <div class="confidence-display">
431
+ <span class="confidence-value {get_confidence_class(percentage)}">
432
+ {percentage:.1f}% confidence
433
+ </span>
434
+ </div>
435
+ """
436
+
437
+ medication_conf_html_val = get_conf_html(med_percentage)
438
+ therapy_conf_html_val = get_conf_html(therapy_percentage)
439
+
440
+ # Explanation
441
+ top_diag_labels = ", ".join([label for label, _ in diagnosis_preds])
442
+ explanation = get_explanation(full_statement, f"{top_diag_labels}, {med_pred} and {therapy_pred}")
443
+
444
+ # Prepare session data
445
+ top_diag_with_conf = ", ".join([f"{label} ({conf * 100:.1f}%)" for label, conf in diagnosis_preds])
446
+ session_data = {
447
+ "patient_name": patient_name,
448
+ "age": age,
449
+ "gender": gender,
450
+ "symptoms": symptoms,
451
+ "diagnosis": top_diag_with_conf,
452
+ "medication": f"{med_pred} ({med_percentage:.1f}% confidence)",
453
+ "therapy": f"{therapy_pred} ({therapy_percentage:.1f}% confidence)",
454
+ "summary": summary,
455
+ "explanation": explanation,
456
+ "session_timestamp": datetime.now().strftime("%Y-%m-%d %H:%M:%S")
457
+ }
458
+ session_data_state.value = json.dumps(session_data)
459
+
460
+ # Save to chat history
461
+ conn = sqlite3.connect("users.db")
462
+ c = conn.cursor()
463
+ if user_session.value:
464
+ c.execute("INSERT INTO chat_history (username, message, response) VALUES (?, ?, ?)",
465
+ (user_session.value, full_statement, top_diag_with_conf))
466
+ conn.commit()
467
+ conn.close()
468
+
469
+ return (
470
+ diagnosis_display, diagnosis_conf_html_val,
471
+ med_pred, medication_conf_html_val,
472
+ therapy_pred, therapy_conf_html_val,
473
+ summary, explanation,
474
+ gr.update(visible=True)
475
+ )
476
+
477
+
478
+ submit.click(handle_chat_extended,
479
+ inputs=[patient_name_input, gender_input, age_input, symptoms_input],
480
+ outputs=[diagnosis_textbox, diagnosis_conf_html, medication_textbox, medication_conf_html,
481
+ therapy_textbox, therapy_conf_html, summary_textbox, explanation_textbox,
482
+ generate_report_btn])
483
+
484
+ def handle_generate_report():
485
+ try:
486
+ data = json.loads(session_data_state.value)
487
+ except:
488
+ return None
489
+ pdf_file = generate_pdf_report(data)
490
+ data["username"] = user_session.value
491
+ data["appointment_date"] = ""
492
+ data["pdf_report"] = pdf_file
493
+ insert_patient_session(data)
494
+ return pdf_file
495
+
496
+ generate_report_btn.click(handle_generate_report, outputs=report_download)
497
+
498
+ def autofill_previous(selected_patient):
499
+ name, age_val, gender_val = get_previous_patient_info(selected_patient)
500
+ return name, age_val, gender_val
501
+
502
+ previous_patient.change(autofill_previous,
503
+ inputs=[previous_patient],
504
+ outputs=[patient_name_input, age_input, gender_input])
505
+
506
+ with gr.Tab("Chat History"):
507
+ history_output = gr.Textbox(label="Chat History", interactive=False)
508
+ load_history_btn = gr.Button("Load History")
509
+
510
+ def load_history():
511
+ if user_session.value:
512
+ history = get_chat_history(user_session.value)
513
+ return "\n".join([f"[{h[2]}] {h[0]}\nBot: {h[1]}" for h in history])
514
+ else:
515
+ return "Please log in to view history."
516
+
517
+ load_history_btn.click(load_history, outputs=history_output)
518
+
519
+ with gr.Tab("Book an Appointment"):
520
+ with gr.Row():
521
+ with gr.Column():
522
+ patient_name_appt = gr.Textbox(label="Patient Name", placeholder="Enter your name")
523
+ doctor_name = gr.Dropdown(choices=["Dr. Smith", "Dr. Johnson", "Dr. Lee"], label="Select Doctor")
524
+ appointment_date = gr.Textbox(label="Appointment Date", placeholder="YYYY-MM-DD")
525
+ appointment_time = gr.Textbox(label="Appointment Time", placeholder="HH:MM (24-hour format)")
526
+ reason = gr.TextArea(label="Reason for Visit", placeholder="Describe your symptoms or reason for the visit")
527
+ book_button = gr.Button("Book Appointment")
528
+ with gr.Column():
529
+ booking_output = gr.Textbox(label="Booking Confirmation", interactive=False)
530
+
531
+ def book_appointment(patient_name, doctor_name, appointment_date, appointment_time, reason):
532
+ if not user_session.value:
533
+ return "Please log in to book an appointment."
534
+ patient_name = (patient_name or "").strip()
535
+ doctor_name = (doctor_name or "").strip()
536
+ appointment_date = (appointment_date or "").strip()
537
+ appointment_time = (appointment_time or "").strip()
538
+ reason = (reason or "").strip()
539
+ if not (patient_name and doctor_name and appointment_date and appointment_time and reason):
540
+ return "Please fill in all the fields."
541
+ if not re.fullmatch(r"[A-Za-z ]+", patient_name):
542
+ return "Patient name should contain only letters and spaces."
543
+ try:
544
+ datetime.strptime(appointment_date, "%Y-%m-%d")
545
+ except ValueError:
546
+ return "Appointment date must be in YYYY-MM-DD format."
547
+ try:
548
+ datetime.strptime(appointment_time, "%H:%M")
549
+ except ValueError:
550
+ return "Appointment time must be in HH:MM (24-hour) format."
551
+ confirmation = (f"Appointment booked for {patient_name} with {doctor_name} on {appointment_date} at {appointment_time}.\n\n"
552
+ f"Reason: {reason}")
553
+ return confirmation
554
+
555
+ book_button.click(book_appointment,
556
+ inputs=[patient_name_appt, doctor_name, appointment_date, appointment_time, reason],
557
+ outputs=booking_output)
558
+
559
+ with gr.Tab("Patient Sessions"):
560
+ gr.Markdown("### Search Patient Sessions")
561
+ search_name = gr.Textbox(label="Patient Name (optional)")
562
+ search_date = gr.Textbox(label="Date (YYYY-MM-DD, optional)")
563
+ search_button = gr.Button("Search")
564
+ sessions_output = gr.Textbox(label="Sessions", interactive=False)
565
+
566
+ def search_sessions(name, date):
567
+ sessions = get_patient_sessions(filter_name=name, filter_date=date)
568
+ if sessions:
569
+ output = "\n\n".join([f"Patient: {s[0]}\nAge: {s[1]}\nGender: {s[2]}\nSymptoms: {s[3]}\nDiagnosis: {s[4]}\nMedication: {s[5]}\nTherapy: {s[6]}\nSummary: {s[7]}\nExplanation: {s[8]}\nReport: {s[9]}\nSession Time: {s[10]}" for s in sessions])
570
+ return output
571
+ else:
572
+ return "No sessions found."
573
+
574
+ search_button.click(search_sessions, inputs=[search_name, search_date], outputs=sessions_output)
575
+
576
+ def handle_login(username, password):
577
+ if login_user(username, password):
578
+ user_session.value = username
579
+ prev_choices = get_previous_patients()
580
+ return (f"Welcome, {username}!",
581
+ gr.update(visible=True),
582
+ gr.update(visible=False),
583
+ gr.update(visible=True),
584
+ gr.update(choices=prev_choices))
585
+ else:
586
+ return "Invalid credentials.", gr.update(), gr.update(), gr.update(), gr.update()
587
+
588
+ def handle_register(username, password, full_name, email):
589
+ return register_user(username, password, full_name, email)
590
+
591
+ def go_to_register_page():
592
+ return gr.update(visible=False), gr.update(visible=True)
593
+
594
+ def back_to_login_page():
595
+ return gr.update(visible=True), gr.update(visible=False)
596
+
597
+ login_btn.click(handle_login,
598
+ inputs=[username_login, password_login],
599
+ outputs=[login_output, main_panel, login_page, header_row])
600
+ go_to_register.click(go_to_register_page, outputs=[login_page, register_page])
601
+ register_btn.click(handle_register,
602
+ inputs=[new_username, new_password, full_name, email],
603
+ outputs=register_output)
604
+ back_to_login.click(back_to_login_page, outputs=[login_page, register_page])
605
+
606
+
607
+ # Toggle profile function
608
+ def toggle_profile(user, current_visible):
609
+ if not user:
610
+ return gr.update(visible=False), False, ""
611
+ new_visible = not current_visible
612
+ info = get_user_info(user) if new_visible else ""
613
+ return gr.update(visible=new_visible), new_visible, info
614
+
615
+
616
+ # Connect profile button click with correct input order:
617
+ profile_button.click(
618
+ toggle_profile,
619
+ inputs=[user_session, profile_visible],
620
+ outputs=[profile_info_box, profile_visible, profile_info]
621
+ )
622
+
623
+
624
+ # Handle login: update previous patients dropdown
625
+ def handle_login(username, password):
626
+ if login_user(username, password):
627
+ user_session.value = username
628
+ prev_choices = get_previous_patients()
629
+ return (f"Welcome, {username}!",
630
+ gr.update(visible=True), # main_panel visible
631
+ gr.update(visible=False), # login_page hidden
632
+ gr.update(visible=True), # header_row visible
633
+ gr.update(choices=prev_choices)) # update dropdown choices
634
+ else:
635
+ return "Invalid credentials.", gr.update(), gr.update(), gr.update(), gr.update()
636
+
637
+
638
+ # Connect login button click:
639
+ login_btn.click(
640
+ handle_login,
641
+ inputs=[username_login, password_login],
642
+ outputs=[login_output, main_panel, login_page, header_row, previous_patient]
643
+ )
644
+
645
+ init_db()
646
+ app.launch()