File size: 4,077 Bytes
c3c7cca
e1e0f62
d1af83e
e1e0f62
d1af83e
 
 
e1e0f62
 
d1af83e
e1e0f62
 
 
 
 
 
d1af83e
e1e0f62
d1af83e
e1e0f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14b56fb
e1e0f62
 
d1af83e
e1e0f62
 
 
 
 
 
 
 
 
 
 
 
d1af83e
 
 
e1e0f62
d1af83e
 
e1e0f62
 
 
 
 
 
d1af83e
e1e0f62
 
 
 
 
d1af83e
e1e0f62
 
 
 
14b56fb
 
e1e0f62
 
 
 
 
 
14b56fb
 
d1af83e
e1e0f62
c3c7cca
d1af83e
 
 
c3c7cca
d1af83e
 
c3c7cca
e1e0f62
d1af83e
 
e1e0f62
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import hashlib
import logging
import os
import random
import re

import gradio as gr
import numpy as np
import pandas as pd
import torch
import transformers
from torch import nn
from torch.nn.functional import cosine_similarity

# ── Configuration ────────────────────────────────────────────────────────────
# Hugging Face checkpoint used to embed both the labels and the user's input.
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

MIN_FREQ = 4            # focus areas with fewer rows than this are pruned
MAX_LEN = 256           # default token cap for encode_text (user input)
VERBALIZE_LABEL = True  # embed a template sentence instead of the bare label

# ── 1) Load & Clean Data ─────────────────────────────────────────────────────
df = pd.read_csv("medquad.csv")

# Build the text field from question + answer. Both halves are fillna("")'d,
# so the column is never NaN and a dropna() would be a no-op; rows with no
# usable text are detected as blank strings and removed further down.
df["text"] = (
    df["question"].fillna("").str.strip()
    + " "
    + df["answer"].fillna("").str.strip()
)

# Hyphen/dash variants folded into a single space: ASCII hyphen-minus,
# U+2010 hyphen, U+2011 non-breaking hyphen, U+2012 figure dash,
# U+2013 en dash and U+2014 em dash. Written with escapes so that no
# character pair can silently form a range inside the class.
dash_pat = "[-\u2010\u2011\u2012\u2013\u2014]"

# normalize hyphens/spaces in both text and labels
df["text"] = df["text"].str.replace(dash_pat, " ", regex=True)
df["focus_area"] = (
    df["focus_area"]
      .fillna("")
      .astype(str)
      .str.replace(dash_pat, " ", regex=True)
      .str.lower()
      .str.replace(r"\s+", " ", regex=True)
      .str.strip()
)

# Drop rows with blank text or a blank label. Without this, rows missing a
# focus_area would survive pruning as an empty-string class that the
# prediction step could return as its answer.
has_text = df["text"].str.strip() != ""
has_label = df["focus_area"] != ""
df = df[has_text & has_label].reset_index(drop=True)

# prune rare labels
vc = df["focus_area"].value_counts()
keep = vc[vc >= MIN_FREQ].index
df = df[df["focus_area"].isin(keep)].reset_index(drop=True)

# ── 2) Tokenizer & Frozen BERT ───────────────────────────────────────────────
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

# The encoder is only used for inference in this script: move it to DEVICE
# and switch it to eval mode (Module.to / Module.eval act in place).
bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME)
bert_model = bert_model.to(DEVICE)
bert_model.eval()

# ── 3) Label ↔ ID maps & label embeddings ────────────────────────────────────
def verbalise(lbl: str) -> str:
    """Return the text that gets embedded to represent label *lbl*.

    When the module-level VERBALIZE_LABEL flag is off, the label is used
    verbatim. Otherwise it is wrapped in a fixed template sentence.
    """
    if not VERBALIZE_LABEL:
        return lbl
    return f"This question is about the medical focus area of {lbl}."

# Sorted, de-duplicated focus areas; list position doubles as the class id.
labels = sorted(df["focus_area"].unique())
label2id = {name: idx for idx, name in enumerate(labels)}
id2label = dict(enumerate(labels))

@torch.no_grad()
def encode_text(s: str, max_length=MAX_LEN) -> torch.Tensor:
    """Embed a single string as the model's first-token ([CLS]) hidden state.

    The input is tokenized without padding and truncated to *max_length*
    tokens, then run through the frozen encoder on DEVICE with autograd
    disabled. The result is squeezed to drop the size-1 batch dimension
    and returned on the CPU.
    """
    batch = tokenizer(
        s,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=False,
    )
    batch = batch.to(DEVICE)
    hidden = bert_model(**batch).last_hidden_state
    cls_vec = hidden[:, 0]  # first token of the only sequence in the batch
    return cls_vec.squeeze().cpu()

# Precompute one embedding per label, in the same order as `labels`, so row i
# of label_embs corresponds to labels[i]. Verbalised labels are short, so a
# 32-token cap is used instead of MAX_LEN.
label_embs = torch.stack(
    tuple(
        encode_text(verbalise(name), max_length=32)
        for name in labels
    )
)

# ── 4) Prediction function ──────────────────────────────────────────────────
def predict_disease(symptoms: str) -> str:
    """Return the focus-area label whose embedding is closest to *symptoms*.

    The input is embedded with ``encode_text`` and compared by cosine
    similarity against every row of ``label_embs``. The label with the
    highest score is returned.

    Empty, whitespace-only or ``None`` input yields a prompt message instead
    of a prediction. Any failure while embedding or scoring is logged with
    its traceback and reported back to the UI as ``"Error: <message>"``
    rather than raised, because this function is the Gradio callback.
    """
    # Tolerate None (e.g. a null value sent by an API client) as "no input".
    symptoms = (symptoms or "").strip()
    if not symptoms:
        return "❗️ Please enter your symptoms."
    try:
        # embed user input
        q_emb = encode_text(symptoms).unsqueeze(0)  # [1, hidden]
        # cosine with each label embedding (broadcast over label rows)
        sims = cosine_similarity(label_embs, q_emb, dim=1)  # [num_labels]
        idx = sims.argmax().item()
        return labels[idx]
    except Exception as e:
        # UI boundary: keep the app responsive, but don't lose the traceback.
        logging.getLogger(__name__).exception("predict_disease failed")
        return f"Error: {e}"

# ── 5) Gradio App ───────────────────────────────────────────────────────────
# Text-in / text-out UI wrapping predict_disease. The title contains a
# non-ASCII emoji and arrow, so this file must be saved and read as UTF-8.
app = gr.Interface(
    fn=predict_disease,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Enter your symptoms here…",
    ),
    outputs="text",
    title="🔬 Symptom→Disease Chatbot",
    description="PubMed-BERT frozen embeddings + cosine similarity",
)

if __name__ == "__main__":
    # Listen on all interfaces; the port is read from $PORT, default 7860.
    port = int(os.environ.get("PORT", 7860))
    app.launch(
        server_name="0.0.0.0",
        server_port=port,
        share=False,
    )