# medical_chatbot / app.py
# Author: atharvasc27112001
# Last commit: "Update app.py" (e1e0f62, verified)
import hashlib
import logging
import os
import random
import re

import gradio as gr
import numpy as np
import pandas as pd
import torch
import transformers
from torch import nn
from torch.nn.functional import cosine_similarity
# ── Configuration ────────────────────────────────────────────────────────────
# Hugging Face hub id of the encoder used to embed both labels and queries.
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

MIN_FREQ = 4            # focus areas with fewer rows than this are dropped
MAX_LEN = 256           # token cap applied to user queries
VERBALIZE_LABEL = True  # embed labels inside a template sentence (see verbalise)

# ── 1) Load & Clean Data ─────────────────────────────────────────────────────
df = pd.read_csv("medquad.csv")

# Build a combined question+answer text field. Note: nothing else in this
# file reads df["text"]; only "focus_area" feeds the label set below.
df["text"] = (
    df["question"].fillna("").str.strip()
    + " "
    + df["answer"].fillna("").str.strip()
)
# Because of fillna("") the text is never NaN, so filter on blank strings
# (rows whose question and answer are both empty) rather than on NaN.
df = df[df["text"].str.strip() != ""].reset_index(drop=True)

# Normalise hyphen/dash variants to spaces in both text and labels.
# The class matches ASCII "-" plus U+2010..U+2014 (hyphen, non-breaking
# hyphen, figure dash, en dash, em dash). It is written with escapes so the
# visually identical characters are explicit.
dash_pat = "[-\u2010-\u2014]"
df["text"] = df["text"].str.replace(dash_pat, " ", regex=True)
df["focus_area"] = (
    df["focus_area"]
    .fillna("")
    .astype(str)
    .str.replace(dash_pat, " ", regex=True)
    .str.lower()
    .str.replace(r"\s+", " ", regex=True)
    .str.strip()
)
# Missing/blank focus areas would otherwise collapse into a single "" label.
# That label could survive pruning and then be returned as an empty prediction.
df = df[df["focus_area"] != ""]

# Prune rare labels: keep only focus areas with at least MIN_FREQ rows.
vc = df["focus_area"].value_counts()
keep = vc[vc >= MIN_FREQ].index
df = df[df["focus_area"].isin(keep)].reset_index(drop=True)

# ── 2) Tokenizer & Frozen BERT ───────────────────────────────────────────────
# Tokenizer and encoder both come from MODEL_NAME. The encoder is moved to
# DEVICE and switched to eval mode. It is never trained in this file; it is
# only called under torch.no_grad() in encode_text.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME)
bert_model = bert_model.to(DEVICE)
bert_model.eval()

# ── 3) Label ↔ ID maps & label embeddings ────────────────────────────────────
def verbalise(lbl: str) -> str:
    """Return the text that is embedded to represent the label *lbl*.

    When ``VERBALIZE_LABEL`` is truthy the label is wrapped in a short template
    sentence; otherwise the raw label string is returned unchanged.
    """
    if not VERBALIZE_LABEL:
        return lbl
    return f"This question is about the medical focus area of {lbl}."

# Alphabetical label order gives deterministic, contiguous integer ids.
labels = sorted(set(df["focus_area"]))
label2id = dict(zip(labels, range(len(labels))))
id2label = dict(enumerate(labels))

@torch.no_grad()
def encode_text(s: str, max_length=MAX_LEN) -> torch.Tensor:
    """Embed *s* with the frozen encoder and return its [CLS] vector on the CPU.

    The text is tokenised as one unpadded sequence truncated to ``max_length``
    tokens and run through ``bert_model`` on ``DEVICE`` without gradient
    tracking. The first-position hidden state is returned with size-1
    dimensions squeezed away, i.e. a ``[hidden]`` vector for a single input.

    NOTE(review): raw [CLS] vectors from an encoder not fine-tuned for sentence
    similarity may be a weak similarity signal; mean pooling could be worth
    evaluating.
    """
    encoded = tokenizer(
        s,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=False,
    )
    encoded = encoded.to(DEVICE)
    hidden_states = bert_model(**encoded).last_hidden_state
    cls_vector = hidden_states[:, 0]  # first token = [CLS]
    return cls_vector.squeeze().cpu()

# Precompute one embedding per label; row i corresponds to labels[i].
# Each verbalised label prompt is truncated at 32 tokens.
label_embs = torch.stack(
    tuple(encode_text(verbalise(lbl), max_length=32) for lbl in labels)
)

# ── 4) Prediction function ──────────────────────────────────────────────────
def predict_disease(symptoms: str) -> str:
    """Return the focus-area label whose embedding is closest to *symptoms*.

    The input text is embedded with ``encode_text`` and compared, by cosine
    similarity, against every row of the precomputed ``label_embs``. The label
    with the highest similarity is returned.

    Args:
        symptoms: User-entered text. ``None``, which an API client may send
            instead of a string, is treated the same as empty input.

    Returns:
        One of the following:
        - the best-matching entry of ``labels``;
        - a prompt asking for input when *symptoms* is blank;
        - an ``"Error: ..."`` message if embedding or scoring fails.
    """
    symptoms = (symptoms or "").strip()
    if not symptoms:
        return "❗️ Please enter your symptoms."
    try:
        # Embed the user input as a single-row batch.
        q_emb = encode_text(symptoms).unsqueeze(0)  # [1, hidden]
        # label_embs is [num_labels, hidden]; q_emb broadcasts across its rows,
        # giving one similarity score per label.
        sims = cosine_similarity(label_embs, q_emb, dim=1)  # [num_labels]
        idx = sims.argmax().item()
        return labels[idx]
    except Exception as e:
        # UI boundary: report the failure to the user instead of crashing the
        # handler, but keep the full traceback in the server log.
        logging.getLogger(__name__).exception("predict_disease failed")
        return f"Error: {e}"

# ── 5) Gradio App ───────────────────────────────────────────────────────────
# The title contains non-ASCII characters (🔬, →); keep this file saved as UTF-8.
app = gr.Interface(
    fn=predict_disease,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Enter your symptoms here…",
    ),
    outputs="text",
    title="🔬 Symptom→Disease Chatbot",
    description="PubMed-BERT frozen embeddings + cosine similarity",
)

if __name__ == "__main__":
    # Listen on all interfaces. A PORT environment variable, if set, overrides
    # the 7860 default.
    port = int(os.environ.get("PORT", 7860))
    app.launch(server_name="0.0.0.0", server_port=port, share=False)