Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -9,15 +9,14 @@ from transformers import (
|
|
9 |
)
|
10 |
import gradio as gr
|
11 |
|
12 |
-
# 1) Device setup
|
13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
|
15 |
-
# 2)
|
16 |
model_name = "paragon-analytics/ADRv1"
|
17 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
18 |
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
|
19 |
|
20 |
-
# 3) Build HF text-classification pipeline
|
21 |
pred_pipeline = pipeline(
|
22 |
"text-classification",
|
23 |
model=model,
|
@@ -26,40 +25,30 @@ pred_pipeline = pipeline(
|
|
26 |
device=0 if device.type == "cuda" else -1
|
27 |
)
|
28 |
|
29 |
-
# 4) Base predict_proba: List[str] β np.ndarray of shape (n_samples, n_classes)
|
30 |
def predict_proba(texts):
|
31 |
if isinstance(texts, str):
|
32 |
texts = [texts]
|
33 |
results = pred_pipeline(texts)
|
34 |
-
|
35 |
-
probs = np.array([[d["score"] for d in sample] for sample in results])
|
36 |
-
return probs
|
37 |
|
38 |
-
# 5) SHAP-compatible wrapper: joins token lists back into strings
|
39 |
def predict_proba_shap(inputs):
|
40 |
-
|
41 |
-
texts = [
|
42 |
-
" ".join(x) if isinstance(x, list) else x
|
43 |
-
for x in inputs
|
44 |
-
]
|
45 |
return predict_proba(texts)
|
46 |
|
47 |
-
#
|
48 |
masker = shap.maskers.Text(tokenizer)
|
49 |
-
# Grab output class labels from a dummy sample
|
50 |
_example = pred_pipeline(["test"])[0]
|
51 |
class_labels = [d["label"] for d in _example]
|
52 |
-
|
53 |
explainer = shap.Explainer(
|
54 |
predict_proba_shap,
|
55 |
masker=masker,
|
56 |
output_names=class_labels
|
57 |
)
|
58 |
|
59 |
-
#
|
60 |
-
|
61 |
-
ner_tokenizer = AutoTokenizer.from_pretrained(
|
62 |
-
ner_model = AutoModelForTokenClassification.from_pretrained(
|
63 |
ner_pipe = pipeline(
|
64 |
"ner",
|
65 |
model=ner_model,
|
@@ -68,7 +57,6 @@ ner_pipe = pipeline(
|
|
68 |
device=0 if device.type == "cuda" else -1
|
69 |
)
|
70 |
|
71 |
-
# 8) Mapping for entity highlight colors
|
72 |
ENTITY_COLORS = {
|
73 |
"Severity": "red",
|
74 |
"Sign_symptom": "green",
|
@@ -79,57 +67,49 @@ ENTITY_COLORS = {
|
|
79 |
"Biological_structure": "silver"
|
80 |
}
|
81 |
|
82 |
-
#
|
83 |
def adr_predict(text: str):
|
84 |
-
#
|
85 |
probs = predict_proba([text])[0]
|
86 |
-
prob_dict = {
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
# c) NER highlighting
|
93 |
ents = ner_pipe(text)
|
94 |
-
highlighted = ""
|
95 |
-
last_idx = 0
|
96 |
for ent in ents:
|
97 |
-
|
98 |
-
|
99 |
color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
|
100 |
-
highlighted +=
|
101 |
-
|
102 |
-
|
103 |
-
)
|
104 |
-
last_idx = end
|
105 |
-
highlighted += text[last_idx:]
|
106 |
-
|
107 |
return prob_dict, fig, highlighted
|
108 |
|
109 |
-
#
|
110 |
with gr.Blocks() as demo:
|
111 |
gr.Markdown("## Welcome to **ADR Detector** πͺ")
|
112 |
gr.Markdown(
|
113 |
-
"Predicts
|
114 |
"_(Not for medical or diagnostic use.)_"
|
115 |
)
|
116 |
|
117 |
txt = gr.Textbox(
|
118 |
-
label="Enter Your Text Here:",
|
119 |
-
lines=3,
|
120 |
placeholder="Type a sentence about an adverse reactionβ¦"
|
121 |
)
|
122 |
btn = gr.Button("Analyze")
|
123 |
|
124 |
with gr.Row():
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
|
129 |
btn.click(
|
130 |
fn=adr_predict,
|
131 |
inputs=txt,
|
132 |
-
outputs=[
|
133 |
)
|
134 |
|
135 |
gr.Examples(
|
@@ -138,9 +118,9 @@ with gr.Blocks() as demo:
|
|
138 |
"A 35-year-old female had minor abdominal pain after Acetaminophen."
|
139 |
],
|
140 |
inputs=txt,
|
141 |
-
outputs=[
|
142 |
fn=adr_predict,
|
143 |
-
cache_examples=
|
144 |
)
|
145 |
|
146 |
if __name__ == "__main__":
|
|
|
9 |
)
|
10 |
import gradio as gr
|
11 |
|
12 |
+
# βββββββββ 1) Device setup βββββββββ
|
13 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
14 |
|
15 |
+
# βββββββββ 2) ADR classifier βββββββββ
|
16 |
model_name = "paragon-analytics/ADRv1"
|
17 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
18 |
model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
|
19 |
|
|
|
20 |
pred_pipeline = pipeline(
|
21 |
"text-classification",
|
22 |
model=model,
|
|
|
25 |
device=0 if device.type == "cuda" else -1
|
26 |
)
|
27 |
|
|
|
28 |
def predict_proba(texts):
|
29 |
if isinstance(texts, str):
|
30 |
texts = [texts]
|
31 |
results = pred_pipeline(texts)
|
32 |
+
return np.array([[d["score"] for d in sample] for sample in results])
|
|
|
|
|
33 |
|
|
|
34 |
def predict_proba_shap(inputs):
|
35 |
+
texts = [" ".join(x) if isinstance(x, list) else x for x in inputs]
|
|
|
|
|
|
|
|
|
36 |
return predict_proba(texts)
|
37 |
|
38 |
+
# βββββββββ 3) SHAP explainer βββββββββ
|
39 |
masker = shap.maskers.Text(tokenizer)
|
|
|
40 |
_example = pred_pipeline(["test"])[0]
|
41 |
class_labels = [d["label"] for d in _example]
|
|
|
42 |
explainer = shap.Explainer(
|
43 |
predict_proba_shap,
|
44 |
masker=masker,
|
45 |
output_names=class_labels
|
46 |
)
|
47 |
|
48 |
+
# βββββββββ 4) Biomedical NER βββββββββ
|
49 |
+
ner_name = "d4data/biomedical-ner-all"
|
50 |
+
ner_tokenizer = AutoTokenizer.from_pretrained(ner_name)
|
51 |
+
ner_model = AutoModelForTokenClassification.from_pretrained(ner_name).to(device)
|
52 |
ner_pipe = pipeline(
|
53 |
"ner",
|
54 |
model=ner_model,
|
|
|
57 |
device=0 if device.type == "cuda" else -1
|
58 |
)
|
59 |
|
|
|
60 |
ENTITY_COLORS = {
|
61 |
"Severity": "red",
|
62 |
"Sign_symptom": "green",
|
|
|
67 |
"Biological_structure": "silver"
|
68 |
}
|
69 |
|
70 |
+
# βββββββββ 5) Prediction + SHAP + NER βββββββββ
|
71 |
def adr_predict(text: str):
|
72 |
+
# Probabilities
|
73 |
probs = predict_proba([text])[0]
|
74 |
+
prob_dict = {cls: float(probs[i]) for i, cls in enumerate(class_labels)}
|
75 |
+
# SHAP
|
76 |
+
shap_vals = explainer([text])
|
77 |
+
fig = shap.plots.text(shap_vals[0], display=False)
|
78 |
+
# NER highlight
|
|
|
|
|
79 |
ents = ner_pipe(text)
|
80 |
+
highlighted, last = "", 0
|
|
|
81 |
for ent in ents:
|
82 |
+
s, e = ent["start"], ent["end"]
|
83 |
+
w = ent["word"].replace("##", "")
|
84 |
color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
|
85 |
+
highlighted += text[last:s] + f"<mark style='background-color:{color};'>{w}</mark>"
|
86 |
+
last = e
|
87 |
+
highlighted += text[last:]
|
|
|
|
|
|
|
|
|
88 |
return prob_dict, fig, highlighted
|
89 |
|
90 |
+
# βββββββββ 6) Gradio UI βββββββββ
|
91 |
with gr.Blocks() as demo:
|
92 |
gr.Markdown("## Welcome to **ADR Detector** πͺ")
|
93 |
gr.Markdown(
|
94 |
+
"Predicts how likely your text describes a **severe** vs. **non-severe** adverse reaction. \n"
|
95 |
"_(Not for medical or diagnostic use.)_"
|
96 |
)
|
97 |
|
98 |
txt = gr.Textbox(
|
99 |
+
label="Enter Your Text Here:", lines=3,
|
|
|
100 |
placeholder="Type a sentence about an adverse reactionβ¦"
|
101 |
)
|
102 |
btn = gr.Button("Analyze")
|
103 |
|
104 |
with gr.Row():
|
105 |
+
out_prob = gr.Label(label="Predicted Probabilities")
|
106 |
+
out_shap = gr.Plot(label="SHAP Explanation")
|
107 |
+
out_ner = gr.HTML(label="Biomedical Entities Highlighted")
|
108 |
|
109 |
btn.click(
|
110 |
fn=adr_predict,
|
111 |
inputs=txt,
|
112 |
+
outputs=[out_prob, out_shap, out_ner]
|
113 |
)
|
114 |
|
115 |
gr.Examples(
|
|
|
118 |
"A 35-year-old female had minor abdominal pain after Acetaminophen."
|
119 |
],
|
120 |
inputs=txt,
|
121 |
+
outputs=[out_prob, out_shap, out_ner],
|
122 |
fn=adr_predict,
|
123 |
+
cache_examples=False # β disable startup caching here
|
124 |
)
|
125 |
|
126 |
if __name__ == "__main__":
|