Spaces:
Sleeping
Sleeping
File size: 4,077 Bytes
c3c7cca e1e0f62 d1af83e e1e0f62 d1af83e e1e0f62 d1af83e e1e0f62 d1af83e e1e0f62 d1af83e e1e0f62 14b56fb e1e0f62 d1af83e e1e0f62 d1af83e e1e0f62 d1af83e e1e0f62 d1af83e e1e0f62 d1af83e e1e0f62 14b56fb e1e0f62 14b56fb d1af83e e1e0f62 c3c7cca d1af83e c3c7cca d1af83e c3c7cca e1e0f62 d1af83e e1e0f62 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
import os
import re, random, hashlib
import pandas as pd
import numpy as np
import torch
import transformers
import gradio as gr
from torch import nn
from torch.nn.functional import cosine_similarity
# ── Configuration ────────────────────────────────────────────────────────────
# Hugging Face checkpoint that supplies both the tokenizer and the encoder.
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# Use the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Focus areas that occur in fewer than MIN_FREQ rows are pruned from the data.
MIN_FREQ = 4

# Default token limit applied when encoding text.
MAX_LEN = 256

# When True, a label is embedded inside a template sentence instead of bare.
VERBALIZE_LABEL = True
# ── 1) Load & Clean Data ─────────────────────────────────────────────────────
df = pd.read_csv("medquad.csv")

# Build one free-text field per row from the question and its answer.
# Missing cells become "" first so the concatenation itself never yields NaN.
df["text"] = (
    df["question"].fillna("").str.strip()
    + " "
    + df["answer"].fillna("").str.strip()
)
# Non-string cells still come out of the .str accessor as NaN; drop those rows.
df = df.dropna(subset=["text"]).reset_index(drop=True)

# Normalize hyphens/dashes to spaces in both text and labels.
# The class is spelled with explicit escapes so it survives any re-encoding of
# this file: ASCII hyphen-minus, U+2010..U+2015 (hyphen, non-breaking hyphen,
# figure dash, en dash, em dash, horizontal bar) and U+2212 (minus sign).
dash_pat = r"[-\u2010-\u2015\u2212]"
df["text"] = df["text"].str.replace(dash_pat, " ", regex=True).str.strip()

# A row whose question and answer are both blank carries no usable text; the
# dropna above cannot catch it because its text is "" rather than NaN.
df = df[df["text"] != ""].reset_index(drop=True)

df["focus_area"] = (
    df["focus_area"]
    .fillna("")
    .astype(str)
    .str.replace(dash_pat, " ", regex=True)
    .str.lower()
    .str.replace(r"\s+", " ", regex=True)
    .str.strip()
)

# A missing focus area was mapped to "" above; it must never become a label,
# however many rows share it.
df = df[df["focus_area"] != ""]

# Prune rare labels: keep only focus areas seen at least MIN_FREQ times.
vc = df["focus_area"].value_counts()
keep = vc[vc >= MIN_FREQ].index
df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
# ── 2) Tokenizer & Frozen BERT ───────────────────────────────────────────────
# Tokenizer and encoder are both loaded from the MODEL_NAME checkpoint.
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

# Load the encoder, move it to the selected device, then switch it to eval
# mode; it is only ever used for inference in this script.
bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME)
bert_model = bert_model.to(DEVICE)
bert_model = bert_model.eval()
# ── 3) Label ↔ ID maps & label embeddings ────────────────────────────────────
def verbalise(lbl: str) -> str:
    """Return the text that is embedded to represent the label ``lbl``.

    When the module-level ``VERBALIZE_LABEL`` flag is set, the label is placed
    inside a fixed template sentence; otherwise it is returned unchanged.
    """
    if not VERBALIZE_LABEL:
        return lbl
    return f"This question is about the medical focus area of {lbl}."
# Sorting makes the label order (and therefore every id) deterministic for a
# given dataset.
labels = sorted(df["focus_area"].unique())

# Position in `labels` is the id; the two dicts are inverses of each other.
id2label = dict(enumerate(labels))
label2id = {name: idx for idx, name in id2label.items()}
@torch.no_grad()
def encode_text(s: str, max_length=MAX_LEN) -> torch.Tensor:
    """Embed ``s`` with the encoder and return its first-token (CLS) vector.

    Args:
        s: Text to embed.
        max_length: Token limit; longer inputs are truncated by the tokenizer.

    Returns:
        The last-layer hidden state at token position 0, with singleton
        dimensions squeezed out and moved to the CPU. Gradients are not
        tracked because of the ``torch.no_grad()`` decorator.
    """
    encoded = tokenizer(
        s,
        padding=False,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    encoded = encoded.to(DEVICE)
    hidden_states = bert_model(**encoded).last_hidden_state
    cls_vector = hidden_states[:, 0]  # position 0 of every sequence (CLS)
    return cls_vector.squeeze().cpu()
# Precompute one embedding per label, row i belonging to labels[i]. Each label
# is verbalised first and encoded with a 32-token limit.
label_embs = torch.stack(
    [encode_text(phrase, max_length=32) for phrase in map(verbalise, labels)]
)
# ── 4) Prediction function ───────────────────────────────────────────────────
def predict_disease(symptoms: str) -> str:
    """Return the focus-area label whose embedding is most similar to the input.

    Args:
        symptoms: Free text entered by the user. ``None`` and whitespace-only
            input are treated the same as an empty string.

    Returns:
        The entry of ``labels`` with the highest cosine similarity to the
        embedded input; a warning message when no text was supplied; or
        ``"Error: <details>"`` when embedding or scoring raises.
    """
    # A caller may pass None instead of "" (e.g. a programmatic request);
    # normalise it here so the emptiness check below covers that case too.
    symptoms = (symptoms or "").strip()
    if not symptoms:
        # U+26A0 U+FE0F is the warning-sign emoji. It is written as escapes so
        # the message cannot be garbled by a re-encoding of this file.
        return "\u26a0\ufe0f Please enter your symptoms."
    try:
        # Embed the user input and add a leading axis so it broadcasts
        # against the stacked label embeddings.
        q_emb = encode_text(symptoms).unsqueeze(0)
        # One cosine score per row of label_embs.
        sims = cosine_similarity(label_embs, q_emb, dim=1)
        idx = sims.argmax().item()
        return labels[idx]
    except Exception as e:
        # UI boundary: report the failure as text rather than raising into
        # the web handler.
        return f"Error: {e}"
# ── 5) Gradio App ────────────────────────────────────────────────────────────
# Single-textbox UI: the entered text goes to predict_disease and the returned
# string is shown as plain text.
# NOTE(review): the non-ASCII characters in the placeholder and title were
# mis-encoded in the source. They are restored here as U+2026 (ellipsis),
# U+1F4AC (speech balloon) and U+2192 (rightwards arrow), written as escapes
# so they survive re-encoding. Confirm the emoji and arrow are the intended
# glyphs.
app = gr.Interface(
    fn=predict_disease,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Enter your symptoms here\u2026",
    ),
    outputs="text",
    title="\U0001F4AC Symptom\u2192Disease Chatbot",
    description="PubMed-BERT frozen embeddings + cosine similarity",
)
if __name__ == "__main__":
    # Listen on every interface; the PORT environment variable overrides the
    # default port 7860. No public share link is requested.
    app.launch(
        share=False,
        server_name="0.0.0.0",
        server_port=int(os.environ.get("PORT", 7860)),
    )
|