# Hugging Face Spaces app (Space status when this copy was captured: Sleeping)
import hashlib
import logging
import os
import random
import re

import gradio as gr
import numpy as np
import pandas as pd
import torch
import transformers
from torch import nn
from torch.nn.functional import cosine_similarity
# ── Configuration ───────────────────────────────────────────────────────────
# Hugging Face hub id of the encoder used for both labels and user queries.
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# Run on the default CUDA device when torch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Focus-area labels occurring fewer than this many times are pruned.
MIN_FREQ = 4
# Default tokenizer max_length for encoding user input.
MAX_LEN = 256
# When True, each label is wrapped in a template sentence before embedding.
VERBALIZE_LABEL = True
# ── 1) Load & Clean Data ────────────────────────────────────────────────────
df = pd.read_csv("medquad.csv")

# Character class: ASCII hyphen-minus plus the Unicode dash punctuation
# U+2010..U+2015 (hyphen, non-breaking hyphen, figure dash, en dash, em dash,
# horizontal bar). Spelled with \u escapes because the previous literal had
# been mis-decoded into Greek "β" characters, so it matched "β" (e.g. in
# "β-blocker") instead of the intended dashes.
dash_pat = r"[-\u2010-\u2015]"

# Build the free-text field from question + answer; a missing half becomes "".
df["text"] = (
    df["question"].fillna("").astype(str).str.strip()
    + " "
    + df["answer"].fillna("").astype(str).str.strip()
).str.strip()
# Because of fillna("") above, "text" is never NaN, so dropna() removed
# nothing. Drop rows whose question and answer were both blank instead.
df = df[df["text"] != ""].reset_index(drop=True)

# Normalize dashes to spaces in the text ...
df["text"] = df["text"].str.replace(dash_pat, " ", regex=True)
# ... and canonicalize labels: dashes -> spaces, lowercase, collapse runs of
# whitespace, trim.
df["focus_area"] = (
    df["focus_area"]
    .fillna("")
    .astype(str)
    .str.replace(dash_pat, " ", regex=True)
    .str.lower()
    .str.replace(r"\s+", " ", regex=True)
    .str.strip()
)

# Prune rare labels. A blank focus_area is missing data, so it is never kept
# as a class, however often it occurs.
vc = df["focus_area"].value_counts()
keep = vc[vc >= MIN_FREQ].index.drop([""], errors="ignore")
df = df[df["focus_area"].isin(keep)].reset_index(drop=True)
# ── 2) Tokenizer & Frozen BERT ──────────────────────────────────────────────
tokenizer = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)

# The encoder is only used for inference in this file (it is never trained),
# so move it to DEVICE once and switch it to eval mode.
bert_model = transformers.AutoModel.from_pretrained(MODEL_NAME)
bert_model = bert_model.to(DEVICE)
bert_model.eval()
# ── 3) Label ↔ ID maps & label embeddings ───────────────────────────────────
def verbalise(lbl: str) -> str:
    """Return the text that is embedded to represent the label ``lbl``.

    With ``VERBALIZE_LABEL`` enabled, the label is wrapped in a fixed prompt
    sentence. Otherwise the label itself is used verbatim.
    """
    if not VERBALIZE_LABEL:
        return lbl
    return f"This question is about the medical focus area of {lbl}."
# Alphabetically sorted label list, so ids are reproducible for the same data.
labels = sorted(df["focus_area"].unique())
# label -> integer id (its position in `labels`)
label2id = {name: idx for idx, name in enumerate(labels)}
# integer id -> label (inverse of label2id)
id2label = dict(enumerate(labels))
def encode_text(s: str, max_length=MAX_LEN) -> torch.Tensor:
    """Embed ``s`` with the frozen encoder and return its CLS vector on the CPU.

    Args:
        s: Text to encode. It is truncated to ``max_length`` tokens.
        max_length: Token budget passed to the tokenizer (default ``MAX_LEN``).

    Returns:
        The first-token (CLS) vector of the last hidden layer for the single
        input. It is a 1-D tensor that does not require grad.
    """
    toks = tokenizer(
        s,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding=False,
    ).to(DEVICE)
    # Pure inference: without no_grad, every call would record an autograd
    # graph and return tensors that require grad, wasting memory.
    with torch.no_grad():
        out = bert_model(**toks).last_hidden_state[:, 0]  # [1, hidden]: CLS
    # Index the single batch row explicitly rather than squeeze(), which would
    # drop every size-1 dimension.
    return out[0].cpu()
# Precompute one embedding per label at start-up. Row i of `label_embs`
# corresponds to labels[i]. A 32-token cap is assumed to be enough for a
# verbalised label sentence.
label_embs = torch.stack(
    tuple(encode_text(verbalise(name), max_length=32) for name in labels)
)
# ── 4) Prediction function ──────────────────────────────────────────────────
def predict_disease(symptoms: str) -> str:
    """Return the focus-area label whose embedding is closest to ``symptoms``.

    The input is embedded with ``encode_text`` and compared by cosine
    similarity against every row of ``label_embs``. The best-scoring entry
    of ``labels`` is returned.

    Args:
        symptoms: User-entered free text. ``None`` is treated as empty.

    Returns:
        One of the following:
        - the predicted label;
        - a prompt asking for input when ``symptoms`` is blank;
        - an ``"Error: ..."`` message when inference fails. This function is
          the UI boundary, so it reports failures instead of raising.
    """
    symptoms = (symptoms or "").strip()
    if not symptoms:
        # "\u26a0\ufe0f" is the warning-sign emoji, escaped because the
        # literal had been mis-decoded into mojibake.
        # NOTE(review): the emoji is a best-guess reconstruction; confirm.
        return "\u26a0\ufe0f Please enter your symptoms."
    try:
        q_emb = encode_text(symptoms).unsqueeze(0)  # [1, hidden]
        # Broadcasts against every label row -> one score per label.
        sims = cosine_similarity(label_embs, q_emb, dim=1)  # [num_labels]
        idx = sims.argmax().item()
    except Exception as e:
        # Log the traceback for operators, but not the symptom text, which may
        # be sensitive.
        logging.getLogger(__name__).exception("predict_disease failed")
        return f"Error: {e}"
    return labels[idx]
# ── 5) Gradio App ───────────────────────────────────────────────────────────
# User-facing strings use escapes (U+2026 horizontal ellipsis, U+1F52C
# microscope, U+2192 rightwards arrow). The literal characters had been
# mis-decoded into Greek-letter mojibake, which Gradio would show verbatim.
# NOTE(review): the title emoji is a best-guess reconstruction; confirm.
app = gr.Interface(
    fn=predict_disease,
    inputs=gr.Textbox(
        lines=3,
        placeholder="Enter your symptoms here\u2026",
    ),
    outputs="text",
    title="\U0001F52C Symptom\u2192Disease Chatbot",
    description="PubMed-BERT frozen embeddings + cosine similarity",
)
if __name__ == "__main__":
    # Listen on all interfaces. The port comes from $PORT when the host sets
    # it, otherwise 7860.
    port = int(os.environ.get("PORT", 7860))
    app.launch(server_name="0.0.0.0", server_port=port, share=False)