paragon-analytics committed on
Commit
d2b9b3e
·
verified ·
1 Parent(s): 41b32e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -51
app.py CHANGED
@@ -9,15 +9,14 @@ from transformers import (
9
  )
10
  import gradio as gr
11
 
12
- # 1) Device setup
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
- # 2) Load ADR classifier model & tokenizer
16
  model_name = "paragon-analytics/ADRv1"
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
19
 
20
- # 3) Build HF text-classification pipeline
21
  pred_pipeline = pipeline(
22
  "text-classification",
23
  model=model,
@@ -26,40 +25,30 @@ pred_pipeline = pipeline(
26
  device=0 if device.type == "cuda" else -1
27
  )
28
 
29
- # 4) Base predict_proba: List[str] → np.ndarray of shape (n_samples, n_classes)
30
  def predict_proba(texts):
31
  if isinstance(texts, str):
32
  texts = [texts]
33
  results = pred_pipeline(texts)
34
- # results: List[List[{"label":…, "score":…}]]
35
- probs = np.array([[d["score"] for d in sample] for sample in results])
36
- return probs
37
 
38
- # 5) SHAP-compatible wrapper: joins token lists back into strings
39
  def predict_proba_shap(inputs):
40
- # inputs: List[str] or List[List[str]]
41
- texts = [
42
- " ".join(x) if isinstance(x, list) else x
43
- for x in inputs
44
- ]
45
  return predict_proba(texts)
46
 
47
- # 6) Instantiate SHAP explainer with a Text masker
48
  masker = shap.maskers.Text(tokenizer)
49
- # Grab output class labels from a dummy sample
50
  _example = pred_pipeline(["test"])[0]
51
  class_labels = [d["label"] for d in _example]
52
-
53
  explainer = shap.Explainer(
54
  predict_proba_shap,
55
  masker=masker,
56
  output_names=class_labels
57
  )
58
 
59
- # 7) Load biomedical NER model & pipeline
60
- ner_model_name = "d4data/biomedical-ner-all"
61
- ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
62
- ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)
63
  ner_pipe = pipeline(
64
  "ner",
65
  model=ner_model,
@@ -68,7 +57,6 @@ ner_pipe = pipeline(
68
  device=0 if device.type == "cuda" else -1
69
  )
70
 
71
- # 8) Mapping for entity highlight colors
72
  ENTITY_COLORS = {
73
  "Severity": "red",
74
  "Sign_symptom": "green",
@@ -79,57 +67,49 @@ ENTITY_COLORS = {
79
  "Biological_structure": "silver"
80
  }
81
 
82
- # 9) Full predict + explain + NER function
83
  def adr_predict(text: str):
84
- # a) Predict probabilities
85
  probs = predict_proba([text])[0]
86
- prob_dict = {label: float(probs[i]) for i, label in enumerate(class_labels)}
87
-
88
- # b) SHAP explanation → Matplotlib figure
89
- shap_values = explainer([text])
90
- fig = shap.plots.text(shap_values[0], display=False)
91
-
92
- # c) NER highlighting
93
  ents = ner_pipe(text)
94
- highlighted = ""
95
- last_idx = 0
96
  for ent in ents:
97
- start, end = ent["start"], ent["end"]
98
- word = ent["word"].replace("##", "")
99
  color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
100
- highlighted += (
101
- text[last_idx:start]
102
- + f"<mark style='background-color:{color};'>{word}</mark>"
103
- )
104
- last_idx = end
105
- highlighted += text[last_idx:]
106
-
107
  return prob_dict, fig, highlighted
108
 
109
- # 10) Build Gradio UI
110
  with gr.Blocks() as demo:
111
  gr.Markdown("## Welcome to **ADR Detector** πŸͺ")
112
  gr.Markdown(
113
- "Predicts the likelihood your text describes a **severe** vs. **non-severe** adverse reaction. \n"
114
  "_(Not for medical or diagnostic use.)_"
115
  )
116
 
117
  txt = gr.Textbox(
118
- label="Enter Your Text Here:",
119
- lines=3,
120
  placeholder="Type a sentence about an adverse reaction…"
121
  )
122
  btn = gr.Button("Analyze")
123
 
124
  with gr.Row():
125
- label_out = gr.Label(label="Predicted Probabilities")
126
- shap_out = gr.Plot(label="SHAP Explanation")
127
- ner_out = gr.HTML(label="Biomedical Entities Highlighted")
128
 
129
  btn.click(
130
  fn=adr_predict,
131
  inputs=txt,
132
- outputs=[label_out, shap_out, ner_out]
133
  )
134
 
135
  gr.Examples(
@@ -138,9 +118,9 @@ with gr.Blocks() as demo:
138
  "A 35-year-old female had minor abdominal pain after Acetaminophen."
139
  ],
140
  inputs=txt,
141
- outputs=[label_out, shap_out, ner_out],
142
  fn=adr_predict,
143
- cache_examples=True
144
  )
145
 
146
  if __name__ == "__main__":
 
9
  )
10
  import gradio as gr
11
 
12
+ # ————————— 1) Device setup —————————
13
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
14
 
15
+ # ————————— 2) ADR classifier —————————
16
  model_name = "paragon-analytics/ADRv1"
17
  tokenizer = AutoTokenizer.from_pretrained(model_name)
18
  model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
19
 
 
20
  pred_pipeline = pipeline(
21
  "text-classification",
22
  model=model,
 
25
  device=0 if device.type == "cuda" else -1
26
  )
27
 
 
28
  def predict_proba(texts):
29
  if isinstance(texts, str):
30
  texts = [texts]
31
  results = pred_pipeline(texts)
32
+ return np.array([[d["score"] for d in sample] for sample in results])
 
 
33
 
 
34
  def predict_proba_shap(inputs):
35
+ texts = [" ".join(x) if isinstance(x, list) else x for x in inputs]
 
 
 
 
36
  return predict_proba(texts)
37
 
38
+ # ————————— 3) SHAP explainer —————————
39
  masker = shap.maskers.Text(tokenizer)
 
40
  _example = pred_pipeline(["test"])[0]
41
  class_labels = [d["label"] for d in _example]
 
42
  explainer = shap.Explainer(
43
  predict_proba_shap,
44
  masker=masker,
45
  output_names=class_labels
46
  )
47
 
48
+ # ————————— 4) Biomedical NER —————————
49
+ ner_name = "d4data/biomedical-ner-all"
50
+ ner_tokenizer = AutoTokenizer.from_pretrained(ner_name)
51
+ ner_model = AutoModelForTokenClassification.from_pretrained(ner_name).to(device)
52
  ner_pipe = pipeline(
53
  "ner",
54
  model=ner_model,
 
57
  device=0 if device.type == "cuda" else -1
58
  )
59
 
 
60
  ENTITY_COLORS = {
61
  "Severity": "red",
62
  "Sign_symptom": "green",
 
67
  "Biological_structure": "silver"
68
  }
69
 
70
+ # ————————— 5) Prediction + SHAP + NER —————————
71
  def adr_predict(text: str):
72
+ # Probabilities
73
  probs = predict_proba([text])[0]
74
+ prob_dict = {cls: float(probs[i]) for i, cls in enumerate(class_labels)}
75
+ # SHAP
76
+ shap_vals = explainer([text])
77
+ fig = shap.plots.text(shap_vals[0], display=False)
78
+ # NER highlight
 
 
79
  ents = ner_pipe(text)
80
+ highlighted, last = "", 0
 
81
  for ent in ents:
82
+ s, e = ent["start"], ent["end"]
83
+ w = ent["word"].replace("##", "")
84
  color = ENTITY_COLORS.get(ent["entity_group"], "lightgray")
85
+ highlighted += text[last:s] + f"<mark style='background-color:{color};'>{w}</mark>"
86
+ last = e
87
+ highlighted += text[last:]
 
 
 
 
88
  return prob_dict, fig, highlighted
89
 
90
+ # ————————— 6) Gradio UI —————————
91
  with gr.Blocks() as demo:
92
  gr.Markdown("## Welcome to **ADR Detector** πŸͺ")
93
  gr.Markdown(
94
+ "Predicts how likely your text describes a **severe** vs. **non-severe** adverse reaction. \n"
95
  "_(Not for medical or diagnostic use.)_"
96
  )
97
 
98
  txt = gr.Textbox(
99
+ label="Enter Your Text Here:", lines=3,
 
100
  placeholder="Type a sentence about an adverse reaction…"
101
  )
102
  btn = gr.Button("Analyze")
103
 
104
  with gr.Row():
105
+ out_prob = gr.Label(label="Predicted Probabilities")
106
+ out_shap = gr.Plot(label="SHAP Explanation")
107
+ out_ner = gr.HTML(label="Biomedical Entities Highlighted")
108
 
109
  btn.click(
110
  fn=adr_predict,
111
  inputs=txt,
112
+ outputs=[out_prob, out_shap, out_ner]
113
  )
114
 
115
  gr.Examples(
 
118
  "A 35-year-old female had minor abdominal pain after Acetaminophen."
119
  ],
120
  inputs=txt,
121
+ outputs=[out_prob, out_shap, out_ner],
122
  fn=adr_predict,
123
+ cache_examples=False # ← disable startup caching here
124
  )
125
 
126
  if __name__ == "__main__":