paragon-analytics commited on
Commit
41b32e2
·
verified ·
1 Parent(s): b8029de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +70 -44
app.py CHANGED
@@ -1,85 +1,102 @@
1
  import numpy as np
2
  import torch
3
  import shap
4
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification
5
-
 
 
 
 
6
  import gradio as gr
7
 
8
  # 1) Device setup
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
11
- # 2) Load ADR classifier
12
  model_name = "paragon-analytics/ADRv1"
13
  tokenizer = AutoTokenizer.from_pretrained(model_name)
14
  model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
15
 
16
- # 3) Hugging Face textclassification pipeline with return_all_scores
17
  pred_pipeline = pipeline(
18
  "text-classification",
19
  model=model,
20
  tokenizer=tokenizer,
21
  return_all_scores=True,
22
- device=0 if device == "cuda" else -1
23
  )
24
 
25
- # 4) Wrapper: list[str]→np.ndarray of shape (n, n_classes)
26
  def predict_proba(texts):
27
  if isinstance(texts, str):
28
  texts = [texts]
29
  results = pred_pipeline(texts)
30
- # results is List[List[{"label":…, "score":…}]]
31
  probs = np.array([[d["score"] for d in sample] for sample in results])
32
  return probs
33
 
34
- # 5) Build SHAP explainer
35
- masker = shap.maskers.Text(tokenizer) # for text explainability
36
- # get output names from a dummy call
37
- example = pred_pipeline(["test"])[0]
38
- class_labels = [d["label"] for d in example]
 
 
 
 
 
 
 
 
 
 
39
  explainer = shap.Explainer(
40
- predict_proba,
41
  masker=masker,
42
  output_names=class_labels
43
  )
44
 
45
- # 6) Load biomedical NER pipeline
46
- ner_tokenizer = AutoTokenizer.from_pretrained("d4data/biomedical-ner-all")
47
- ner_model = AutoModelForTokenClassification.from_pretrained("d4data/biomedical-ner-all")
 
48
  ner_pipe = pipeline(
49
  "ner",
50
  model=ner_model,
51
  tokenizer=ner_tokenizer,
52
  aggregation_strategy="simple",
53
- device=0 if device == "cuda" else -1
54
  )
55
 
56
- # 7) Single‐text prediction + SHAP + NER
57
- def adr_predict(text):
 
 
 
 
 
 
 
 
 
 
 
58
  # a) Predict probabilities
59
- probs = predict_proba(text)[0]
60
  prob_dict = {label: float(probs[i]) for i, label in enumerate(class_labels)}
61
 
62
- # b) SHAP explanation (returns a Matplotlib figure)
63
  shap_values = explainer([text])
64
  fig = shap.plots.text(shap_values[0], display=False)
65
 
66
  # c) NER highlighting
67
- entities = ner_pipe(text)
68
- colors = {
69
- "Severity": "red",
70
- "Sign_symptom": "green",
71
- "Medication": "lightblue",
72
- "Age": "yellow",
73
- "Sex": "yellow",
74
- "Diagnostic_procedure": "gray",
75
- "Biological_structure": "silver"
76
- }
77
  highlighted = ""
78
  last_idx = 0
79
- for ent in entities:
80
  start, end = ent["start"], ent["end"]
81
  word = ent["word"].replace("##", "")
82
- color = colors.get(ent["entity_group"], "lightgray")
83
  highlighted += (
84
  text[last_idx:start]
85
  + f"<mark style='background-color:{color};'>{word}</mark>"
@@ -89,23 +106,31 @@ def adr_predict(text):
89
 
90
  return prob_dict, fig, highlighted
91
 
92
- # 8) Gradio UI
93
  with gr.Blocks() as demo:
94
  gr.Markdown("## Welcome to **ADR Detector** 🪐")
95
  gr.Markdown(
96
- "Predicts the likelihood your text describes a severe vs. non-severe adverse reaction. "
97
- "_(Not for medical diagnosis.)_"
98
  )
99
 
100
- txt = gr.Textbox(label="Enter Your Text Here:", lines=3, placeholder="Type a sentence about a reaction…")
 
 
 
 
101
  btn = gr.Button("Analyze")
102
 
103
  with gr.Row():
104
- lbl = gr.Label(label="Predicted Probabilities")
105
- shp = gr.Plot(label="SHAP Explanation")
106
- ner = gr.HTML(label="Biomedical Entities Highlighted")
107
 
108
- btn.click(fn=adr_predict, inputs=txt, outputs=[lbl, shp, ner])
 
 
 
 
109
 
110
  gr.Examples(
111
  examples=[
@@ -113,9 +138,10 @@ with gr.Blocks() as demo:
113
  "A 35-year-old female had minor abdominal pain after Acetaminophen."
114
  ],
115
  inputs=txt,
116
- outputs=[lbl, shp, ner],
117
  fn=adr_predict,
118
  cache_examples=True
119
  )
120
 
121
- demo.launch()
 
 
1
  import numpy as np
2
  import torch
3
  import shap
4
+ from transformers import (
5
+ pipeline,
6
+ AutoTokenizer,
7
+ AutoModelForSequenceClassification,
8
+ AutoModelForTokenClassification
9
+ )
10
  import gradio as gr
11
 
12
  # 1) Device setup
13
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
+ # 2) Load ADR classifier model & tokenizer
16
  model_name = "paragon-analytics/ADRv1"
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
19
 
20
+ # 3) Build HF text-classification pipeline
21
  pred_pipeline = pipeline(
22
  "text-classification",
23
  model=model,
24
  tokenizer=tokenizer,
25
  return_all_scores=True,
26
+ device=0 if device.type == "cuda" else -1
27
  )
28
 
29
+ # 4) Base predict_proba: List[str] np.ndarray of shape (n_samples, n_classes)
30
  def predict_proba(texts):
31
  if isinstance(texts, str):
32
  texts = [texts]
33
  results = pred_pipeline(texts)
34
+ # results: List[List[{"label":…, "score":…}]]
35
  probs = np.array([[d["score"] for d in sample] for sample in results])
36
  return probs
37
 
38
+ # 5) SHAP-compatible wrapper: joins token lists back into strings
39
+ def predict_proba_shap(inputs):
40
+ # inputs: List[str] or List[List[str]]
41
+ texts = [
42
+ " ".join(x) if isinstance(x, list) else x
43
+ for x in inputs
44
+ ]
45
+ return predict_proba(texts)
46
+
47
+ # 6) Instantiate SHAP explainer with a Text masker
48
+ masker = shap.maskers.Text(tokenizer)
49
+ # Grab output class labels from a dummy sample
50
+ _example = pred_pipeline(["test"])[0]
51
+ class_labels = [d["label"] for d in _example]
52
+
53
  explainer = shap.Explainer(
54
+ predict_proba_shap,
55
  masker=masker,
56
  output_names=class_labels
57
  )
58
 
59
+ # 7) Load biomedical NER model & pipeline
60
+ ner_model_name = "d4data/biomedical-ner-all"
61
+ ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
62
+ ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)
63
  ner_pipe = pipeline(
64
  "ner",
65
  model=ner_model,
66
  tokenizer=ner_tokenizer,
67
  aggregation_strategy="simple",
68
+ device=0 if device.type == "cuda" else -1
69
  )
70
 
71
+ # 8) Mapping for entity highlight colors
72
+ ENTITY_COLORS = {
73
+ "Severity": "red",
74
+ "Sign_symptom": "green",
75
+ "Medication": "lightblue",
76
+ "Age": "yellow",
77
+ "Sex": "yellow",
78
+ "Diagnostic_procedure": "gray",
79
+ "Biological_structure": "silver"
80
+ }
81
+
82
+ # 9) Full predict + explain + NER function
83
+ def adr_predict(text: str):
84
  # a) Predict probabilities
85
+ probs = predict_proba([text])[0]
86
  prob_dict = {label: float(probs[i]) for i, label in enumerate(class_labels)}
87
 
88
+ # b) SHAP explanation Matplotlib figure
89
  shap_values = explainer([text])
90
  fig = shap.plots.text(shap_values[0], display=False)
91
 
92
  # c) NER highlighting
93
+ ents = ner_pipe(text)
 
 
 
 
 
 
 
 
 
94
  highlighted = ""
95
  last_idx = 0
96
+ for ent in ents:
97
  start, end = ent["start"], ent["end"]
98
  word = ent["word"].replace("##", "")
99
+ color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
100
  highlighted += (
101
  text[last_idx:start]
102
  + f"<mark style='background-color:{color};'>{word}</mark>"
 
106
 
107
  return prob_dict, fig, highlighted
108
 
109
+ # 10) Build Gradio UI
110
  with gr.Blocks() as demo:
111
  gr.Markdown("## Welcome to **ADR Detector** 🪐")
112
  gr.Markdown(
113
+ "Predicts the likelihood your text describes a **severe** vs. **non-severe** adverse reaction. \n"
114
+ "_(Not for medical or diagnostic use.)_"
115
  )
116
 
117
+ txt = gr.Textbox(
118
+ label="Enter Your Text Here:",
119
+ lines=3,
120
+ placeholder="Type a sentence about an adverse reaction…"
121
+ )
122
  btn = gr.Button("Analyze")
123
 
124
  with gr.Row():
125
+ label_out = gr.Label(label="Predicted Probabilities")
126
+ shap_out = gr.Plot(label="SHAP Explanation")
127
+ ner_out = gr.HTML(label="Biomedical Entities Highlighted")
128
 
129
+ btn.click(
130
+ fn=adr_predict,
131
+ inputs=txt,
132
+ outputs=[label_out, shap_out, ner_out]
133
+ )
134
 
135
  gr.Examples(
136
  examples=[
 
138
  "A 35-year-old female had minor abdominal pain after Acetaminophen."
139
  ],
140
  inputs=txt,
141
+ outputs=[label_out, shap_out, ner_out],
142
  fn=adr_predict,
143
  cache_examples=True
144
  )
145
 
146
+ if __name__ == "__main__":
147
+ demo.launch()