Update app.py
app.py CHANGED
@@ -2,14 +2,13 @@ import gradio as gr
 import torch
 import joblib
 import numpy as np
+import shap
+import random
 from itertools import product
 import torch.nn as nn
-import matplotlib
-matplotlib.use("Agg") # If no display environment
 import matplotlib.pyplot as plt
 import io
 from PIL import Image
-import shap
 
 ###############################################################################
 # Model Definition
@@ -33,35 +32,30 @@ class VirusClassifier(nn.Module):
 
     def forward(self, x):
         return self.network(x)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        return out
-
+
+    def get_feature_importance(self, x):
+        """
+        Calculate gradient-based feature importance, specifically for the
+        'human' class (index=1) by computing gradient of that probability wrt x.
+        """
+        x.requires_grad_(True)
+        output = self.network(x)
+        probs = torch.softmax(output, dim=1)
+
+        # Probability of 'human' class (index=1)
+        human_prob = probs[..., 1]
+        if x.grad is not None:
+            x.grad.zero_()
+        human_prob.backward()
+        importance = x.grad  # shape: (batch_size, n_features)
+
+        return importance, float(human_prob)
 
 ###############################################################################
 # Utility Functions
 ###############################################################################
 def parse_fasta(text):
-    """
-    Parse FASTA text, returning a list of (header, sequence).
-    We'll only use the *first* sequence in practice.
-    """
+    """Parses text input in FASTA format into a list of (header, sequence)."""
     sequences = []
     current_header = None
     current_sequence = []
@@ -82,9 +76,7 @@ def parse_fasta(text):
     return sequences
 
 def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
-    """
-    Convert a single sequence to a 4^k dimension k-mer frequency vector.
-    """
+    """Convert a single nucleotide sequence to a k-mer frequency vector."""
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    vec = np.zeros(len(kmers), dtype=np.float32)
@@ -96,223 +88,403 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
 
     total_kmers = len(sequence) - k + 1
     if total_kmers > 0:
-        vec
+        vec = vec / total_kmers  # normalize frequencies
 
     return vec
 
 ###############################################################################
-#
+# Additional Plots
 ###############################################################################
-def
+def create_probability_bar_plot(prob_human, prob_nonhuman):
     """
-
-    (shap_values: shape=(num_features,))
-    (base_value: scalar)
-    (data: original input data for that sample, shape=(num_features,))
+    Simple bar plot comparing human vs. non-human probabilities.
     """
-
-
-
-
-
-
-    )
-
-
-
-
-
-
-    plt.close(fig)
-    return im
+    labels = ["Non-human", "Human"]
+    probs = [prob_nonhuman, prob_human]
+    colors = ["red", "green"]
+
+    fig, ax = plt.subplots(figsize=(6, 4))
+    ax.bar(labels, probs, color=colors, alpha=0.7)
+    ax.set_ylim(0, 1)
+    for i, v in enumerate(probs):
+        ax.text(i, v+0.02, f"{v:.3f}", ha='center', color='black', fontsize=11)
+
+    ax.set_title("Predicted Probabilities")
+    ax.set_ylabel("Probability")
+    plt.tight_layout()
+    return fig
 
-def
+def create_frequency_sigma_plot(important_kmers, title):
     """
-
-
-    raw_freq: (256,) unscaled frequency
-    scaled_vec: (256,) scaled frequency (z-scores)
-    kmer_list: list of length=256
+    Creates a bar plot of the top k-mers (by importance) showing
+    frequency (%) and σ from mean.
     """
-
-
-
+    # Sort by absolute impact
+    sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
+    kmers = [k["kmer"] for k in sorted_kmers]
+    frequencies = [k["occurrence"] for k in sorted_kmers]  # in %
+    sigmas = [k["sigma"] for k in sorted_kmers]
+    directions = [k["direction"] for k in sorted_kmers]
 
-    for idx in top_indices:
-        idx = int(idx)  # ensure it's an integer
-        top_data.append({
-            "kmer": kmer_list[idx],
-            "shap": shap_values[idx],
-            "abs_shap": abs_vals[idx],
-            "freq": raw_freq[idx] * 100.0,
-            "sigma": scaled_vec[idx]
-        })
-
-    # Sort again by abs_shap desc
-    top_data.sort(key=lambda x: x["abs_shap"], reverse=True)
-
-    kmers = [d["kmer"] for d in top_data]
-    freqs = [d["freq"] for d in top_data]
-    sigmas = [d["sigma"] for d in top_data]
-    colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
-
     x = np.arange(len(kmers))
     width = 0.4
 
-    fig,
+    fig, ax_bar = plt.subplots(figsize=(10, 5))
 
-
-
-
-
+    # Bar for frequency
+    bars_freq = ax_bar.bar(
+        x - width/2, frequencies, width, alpha=0.7,
+        color=["green" if d=="human" else "red" for d in directions],
+        label="Frequency (%)"
+    )
+    ax_bar.set_ylabel("Frequency (%)")
+    ax_bar.set_ylim(0, max(frequencies) * 1.2 if len(frequencies) > 0 else 1)
+
+    # Twin axis for σ
+    ax_bar_twin = ax_bar.twinx()
+    bars_sigma = ax_bar_twin.bar(
+        x + width/2, sigmas, width, alpha=0.5, color="gray", label="σ from Mean"
+    )
 
-
-
-    ax2.set_ylabel("Standard Deviations (σ)")
+    ax_bar_twin.set_ylabel("Standard Deviations (σ)")
+
+    ax_bar.set_title(f"Frequency & σ from Mean for Top k-mers — {title}")
+    ax_bar.set_xticks(x)
+    ax_bar.set_xticklabels(kmers, rotation=45, ha='right')
+
+    # Combined legend
+    lines1, labels1 = ax_bar.get_legend_handles_labels()
+    lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
+    ax_bar.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
 
+    plt.tight_layout()
+    return fig
+
+def create_importance_bar_plot(important_kmers, title):
+    """
+    Create a simple bar chart showing the absolute gradient magnitude
+    for the top k-mers, sorted descending.
+    """
+    sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
+    kmers = [k['kmer'] for k in sorted_kmers]
+    impacts = [k['impact'] for k in sorted_kmers]
+    directions = [k["direction"] for k in sorted_kmers]
+
+    x = np.arange(len(kmers))
+
+    fig, ax = plt.subplots(figsize=(10, 5))
+    bar_colors = ["green" if d=="human" else "red" for d in directions]
+
+    ax.bar(x, impacts, color=bar_colors, alpha=0.7, edgecolor='black')
     ax.set_xticks(x)
     ax.set_xticklabels(kmers, rotation=45, ha='right')
-    ax.set_title(title)
-
-
-    lines2, labels2 = ax2.get_legend_handles_labels()
-    ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
+    ax.set_title(f"Absolute Feature Importance (Top k-mers) — {title}")
+    ax.set_ylabel("Gradient Magnitude")
+    ax.grid(axis="y", alpha=0.3)
 
     plt.tight_layout()
-
-    fig.savefig(buf, format='png', bbox_inches='tight')
-    buf.seek(0)
-    im = Image.open(buf)
-    plt.close(fig)
-    return im
+    return fig
 
+###############################################################################
+# SHAP Beeswarm
+###############################################################################
+def create_shap_beeswarm_plot(
+    model,
+    input_vector: np.ndarray,
+    background_data: np.ndarray,
+    feature_names: list
+):
+    """
+    Creates a SHAP beeswarm plot using KernelExplainer for the given model and data.
+
+    Parameters
+    ----------
+    model : nn.Module
+        Trained PyTorch model (binary classifier).
+    input_vector : np.ndarray
+        The 1-sample input (or multiple samples) we want SHAP values for.
+    background_data : np.ndarray
+        Background samples for KernelExplainer. Should have shape (N, #features).
+    feature_names : list
+        Names for each feature (k-mers).
+
+    Returns
+    -------
+    fig : matplotlib Figure
+        Beeswarm plot figure.
+    """
+
+    # We'll define a prediction function that shap can call
+    # The model outputs logits for shape [N, 2]
+    # We want the raw outputs for each class. SHAP will handle the link function if needed.
+    def predict_fn(data):
+        """
+        data: shape (N, #features)
+        returns: shape (N, 2) for 2-class logits
+        """
+        with torch.no_grad():
+            x = torch.FloatTensor(data)
+            logits = model(x)
+        return logits.detach().cpu().numpy()
+
+    # Create KernelExplainer
+    explainer = shap.KernelExplainer(
+        model=predict_fn,
+        data=background_data
+    )
+
+    # Compute SHAP values
+    # For a 2-class model, shap_values is a list of length 2 => [class0 array, class1 array]
+    # Each array is shape (N, #features).
+    shap_values = explainer.shap_values(input_vector)
+
+    # We'll produce a beeswarm for the 'human' class (class index=1).
+    # If we have only 1 sample, the beeswarm won't be too interesting, but let's do it anyway.
+    class_idx = 1  # 'human'
+
+    # If we only have one sample, place it in an array for shap summary plotting:
+    # We can do shap_values[class_idx].shape => (1, #features) for a single sample
+    # Beeswarm typically expects multiple samples. We'll plot anyway.
+    shap.plots.beeswarm(
+        shap_values[class_idx],
+        feature_names=feature_names,
+        show=False
+    )
+
+    fig = plt.gcf()
+    fig.set_size_inches(8, 6)
+    plt.title("SHAP Beeswarm Plot (Class: Human)")
+
+    plt.tight_layout()
+    return fig
 
 ###############################################################################
-#
+# Prediction Function
 ###############################################################################
-def
+def predict(file_obj):
     """
-
-
-
-
+    Main function for Gradio:
+    1. Reads the uploaded FASTA file or text.
+    2. Loads the model and scaler.
+    3. Generates predictions, probabilities, and top k-mers.
+    4. Creates multiple outputs:
+       - Text summary (Markdown)
+       - Probability Bar Plot
+       - SHAP Beeswarm Plot
+       - Frequency & σ Plot
+       - Absolute Feature Importance Bar Plot
     """
-    #
+    # 0. Basic file read
+    if file_obj is None:
+        return (
+            "Please upload a FASTA file.",
+            None,
+            None,
+            None,
+            None
+        )
+
     try:
         if isinstance(file_obj, str):
             text = file_obj
         else:
-            text = file_obj.decode(
-    except:
-        return
+            text = file_obj.decode('utf-8')
+    except Exception as e:
+        return (
+            f"Error reading file: {str(e)}",
+            None,
+            None,
+            None,
+            None
+        )
 
-    #
+    # 1. Parse FASTA
     sequences = parse_fasta(text)
-    if
-        return
-
+    if len(sequences) == 0:
+        return (
+            "No valid FASTA sequences found. Please check your input.",
+            None,
+            None,
+            None,
+            None
+        )
+    header, seq = sequences[0]  # We'll classify only the first sequence
 
-    #
+    # 2. Prepare model, scaler, and input
     k = 4
-    raw_freq_vec = sequence_to_kmer_vector(seq, k=k)  # shape=(256,)
-
-    # 4. Load model & scaler
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    model = VirusClassifier(4**k).to(device)
     try:
-
+        raw_freq_vector = sequence_to_kmer_vector(seq, k=k)
+
+        # Load model & scaler
+        model = VirusClassifier(input_shape=4**k).to(device)
+        state_dict = torch.load("model.pt", map_location=device)
         model.load_state_dict(state_dict)
-        model.eval()
         scaler = joblib.load("scaler.pkl")
-
-        return f"Error loading model/scaler: {str(e)}", None, None
-
-    # 5. Scale data
-    scaled_data = scaler.transform(raw_freq_vec.reshape(1, -1))  # shape=(1, 256)
-    X_tensor = torch.FloatTensor(scaled_data).to(device)
-
-    # 6. Predict
-    with torch.no_grad():
-        out = model(X_tensor)  # shape=(1,2)
-        probs = torch.softmax(out, dim=1).cpu().numpy()[0]  # shape=(2,)
-    pred_class = np.argmax(probs)
-    pred_label = "human" if pred_class == 1 else "non-human"
-    confidence = float(np.max(probs))
-    human_prob = float(probs[1])
-    nonhuman_prob = float(probs[0])
-
-    # 7. SHAP Explanation (single sample)
-    # We'll wrap the model for shap
-    wrapped_model = TorchModelWrapper(model, device)
-    background_data = scaled_data  # 1 sample as background (a bit silly, but simpler)
-    explainer = shap.Explainer(wrapped_model, background_data)
-    shap_values = explainer(scaled_data)  # shape => (1, 2, 256) for 2-class output
-
-    # We'll pick class=1's shap values
-    sample_shap_vals = shap_values.values[0, 1, :]  # shape=(256,)
-    base_value = shap_values.base_values[0, 1]
-    sample_data = shap_values.data[0]  # shape=(256,)
-
-    # 8. Create the two plots
-    wf_img = create_waterfall_plot(
-        shap_values=sample_shap_vals,
-        base_value=base_value,
-        data=sample_data,
-        max_display=15
-    )
-
-    # freq-sigma plot
-    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
-    freq_img = create_freq_sigma_plot(
-        shap_values=sample_shap_vals,
-        raw_freq=raw_freq_vec,
-        scaled_vec=scaled_data[0],
-        kmer_list=kmers,
-        title=f"{header[:25]}... (Top-10 K-mers)"
-    )
-
-    # 9. Markdown result
-    result_md = f"""
-# Classification Result
-
-**Header**: {header}
-
-**Predicted Label**: {pred_label}
-**Confidence**: {confidence:.4f}
-
-**Human Probability**: {human_prob:.4f}
-**Non-human Probability**: {nonhuman_prob:.4f}
+        model.eval()
+
+        scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
+        X_tensor = torch.FloatTensor(scaled_vector).to(device)
+
+        # 3. Predict
+        with torch.no_grad():
+            logits = model(X_tensor)
+            probs = torch.softmax(logits, dim=1)
+        human_prob = float(probs[0][1])
+        non_human_prob = float(probs[0][0])
+        pred_label = "human" if human_prob >= non_human_prob else "non-human"
+        confidence = float(max(probs[0]))
+
+        # 4. Gradient-based feature importance
+        importance, hum_prob_grad = model.get_feature_importance(X_tensor)
+        importances = importance[0].cpu().numpy()  # shape: (#features,)
+        abs_importances = np.abs(importances)
+
+        # 5. Gather k-mer strings
+        kmers_list = [''.join(p) for p in product("ACGT", repeat=k)]
+        # top 10 by absolute importance
+        top_k = 10
+        top_idxs = np.argsort(abs_importances)[-top_k:][::-1]
+        important_kmers = []
+        for idx in top_idxs:
+            direction = "human" if importances[idx] > 0 else "non-human"
+            freq_percent = float(raw_freq_vector[idx] * 100.0)
+            sigma_val = float(scaled_vector[0][idx])  # scaled / standardized val
+            important_kmers.append({
+                'kmer': kmers_list[idx],
+                'idx': idx,
+                'impact': abs_importances[idx],
+                'direction': direction,
+                'occurrence': freq_percent,
+                'sigma': sigma_val
+            })
+
+        # 6. Generate text summary
+        text_summary = (
+            f"**Sequence Header**: {header}\n\n"
+            f"**Predicted Label**: {pred_label}\n"
+            f"**Confidence**: {confidence:.4f}\n\n"
+            f"**Human Probability**: {human_prob:.4f}\n"
+            f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
+            "### Most Influential k-mers:\n"
+        )
+        for km in important_kmers:
+            direction_text = f"(pushes toward {km['direction']})"
+            freq_text = f"{km['occurrence']:.2f}%"
+            sigma_text = (
+                f"{abs(km['sigma']):.2f}σ "
+                + ("above" if km['sigma'] > 0 else "below")
+                + " mean"
+            )
+            text_summary += (
+                f"- **{km['kmer']}**: impact={km['impact']:.4f}, {direction_text}, "
+                f"occurrence={freq_text}, ({sigma_text})\n"
+            )
+
+        # 7. Probability Bar Plot
+        fig_prob = create_probability_bar_plot(human_prob, non_human_prob)
+        buf_prob = io.BytesIO()
+        fig_prob.savefig(buf_prob, format='png', bbox_inches='tight', dpi=120)
+        buf_prob.seek(0)
+        prob_img = Image.open(buf_prob)
+        plt.close(fig_prob)
+
+        # 8. SHAP Beeswarm Plot
+        # We need some background data for KernelExplainer. Let's create a small random sample
+        # or sample from the scaled_vector itself in a repeated manner. Real usage: choose a valid background set.
+        background_size = 5  # keep small for speed
+        # We'll pick random sequences from normal(0,1) or from scaled_vector repeated
+        background_data = []
+        for _ in range(background_size):
+            # Option A: random small variations around scaled_vector
+            # new_sample = scaled_vector[0] + np.random.normal(0, 0.5, size=scaled_vector.shape[1])
+            # Option B: just clone the same scaled vector multiple times
+            new_sample = scaled_vector[0]
+            background_data.append(new_sample)
+        background_data = np.stack(background_data, axis=0)  # shape (5, #features)
+
+        fig_bee = create_shap_beeswarm_plot(
+            model=model,
+            input_vector=scaled_vector,       # our single sample
+            background_data=background_data,  # background for KernelExplainer
+            feature_names=kmers_list
+        )
+        buf_bee = io.BytesIO()
+        fig_bee.savefig(buf_bee, format='png', bbox_inches='tight', dpi=120)
+        buf_bee.seek(0)
+        bee_img = Image.open(buf_bee)
+        plt.close(fig_bee)
+
+        # 9. Frequency & σ Plot
+        fig_freq = create_frequency_sigma_plot(important_kmers, header)
+        buf_freq = io.BytesIO()
+        fig_freq.savefig(buf_freq, format='png', bbox_inches='tight', dpi=120)
+        buf_freq.seek(0)
+        freq_img = Image.open(buf_freq)
+        plt.close(fig_freq)
+
+        # 10. Absolute Feature Importance Bar Plot
+        fig_imp = create_importance_bar_plot(important_kmers, header)
+        buf_imp = io.BytesIO()
+        fig_imp.savefig(buf_imp, format='png', bbox_inches='tight', dpi=120)
+        buf_imp.seek(0)
+        imp_img = Image.open(buf_imp)
+        plt.close(fig_imp)
+
+        return text_summary, prob_img, bee_img, freq_img, imp_img
+
+    except Exception as e:
+        return (
+            f"Error during prediction or visualization: {str(e)}",
+            None,
+            None,
+            None,
+            None
+        )
 
-
-
 
-
 ###############################################################################
 # Gradio Interface
 ###############################################################################
-with gr.Blocks(title="
-    gr.Markdown(
-
-
-
+with gr.Blocks(title="Advanced Virus Host Classifier with SHAP Beeswarm") as demo:
+    gr.Markdown(
+        """
+        # Advanced Virus Host Classifier (SHAP Beeswarm Edition)
+
+        **Upload a FASTA file** containing a single nucleotide sequence.
+        The model will predict whether this sequence is **human** or **non-human**,
+        provide a confidence score, and highlight the most influential k-mers.
+        We also produce a **SHAP beeswarm** plot for the features.
+
+        ---
+        **Note**: Beeswarm plots are usually most insightful with multiple samples.
+        Here, we demonstrate usage with a single sample plus a small synthetic background.
+        """
+    )
+
+    with gr.Row():
+        file_in = gr.File(label="Upload FASTA", type="binary")
+        btn = gr.Button("Run Prediction")
 
+    # We will create multiple tabs for our outputs
     with gr.Tabs():
-        with gr.Tab("Results"):
+        with gr.Tab("Prediction Results"):
             md_out = gr.Markdown()
-        with gr.Tab("
-
-        with gr.Tab("
-
-
-
-
-
-
+        with gr.Tab("Probability Plot"):
+            prob_out = gr.Image()
+        with gr.Tab("SHAP Beeswarm Plot"):
+            bee_out = gr.Image()
+        with gr.Tab("Frequency & σ Plot"):
+            freq_out = gr.Image()
+        with gr.Tab("Importance Bar Plot"):
+            imp_out = gr.Image()
+
+    # Link the button
+    btn.click(
+        fn=predict,
+        inputs=[file_in],
+        outputs=[md_out, prob_out, bee_out, freq_out, imp_out]
     )
 
-# No share=True -> avoid HF Spaces warning
 if __name__ == "__main__":
-
+    # By default, share=False. You can set share=True for external access.
+    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
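
For quick reference outside the Space, below is a minimal standalone sketch of the k-mer featurization this commit keeps. The counting loop in the middle of sequence_to_kmer_vector is elided context in the diff above, so its body here is an assumed reconstruction from the surrounding definitions, not the committed code:

import numpy as np
from itertools import product

def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """Convert a single nucleotide sequence to a k-mer frequency vector."""
    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    kmer_dict = {km: i for i, km in enumerate(kmers)}
    vec = np.zeros(len(kmers), dtype=np.float32)
    # Assumed counting loop (elided from the diff): slide a window of width k
    # and count exact ACGT k-mers; windows with other characters are skipped.
    for i in range(len(sequence) - k + 1):
        kmer = sequence[i:i + k]
        if kmer in kmer_dict:
            vec[kmer_dict[kmer]] += 1
    total_kmers = len(sequence) - k + 1
    if total_kmers > 0:
        vec = vec / total_kmers  # normalize frequencies, as in the committed code
    return vec

print(sequence_to_kmer_vector("ACGTACGTACGTACGT").shape)  # (256,) since 4**4 = 256
print(sequence_to_kmer_vector("ACGTACGTACGTACGT").sum())  # ~1.0 for a clean ACGT sequence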
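
A caveat on the beeswarm step: the background passed to shap.KernelExplainer in predict() is five identical copies of the classified sample itself (Option B in the committed code). When the background equals the explained input, every feature coalition evaluates to the same model output, so the resulting SHAP values are all (near) zero and the beeswarm carries little signal; the code's own comment ("Real usage: choose a valid background set") points the same way. Below is a hedged sketch of a more informative background, assuming the app's sequence_to_kmer_vector and fitted scaler are in scope; random_background is a hypothetical helper, not part of the commit (note the commit already imports random but does not use it):

import random
import numpy as np

def random_background(scaler, k=4, n=20, seq_len=500, seed=0):
    """Hypothetical helper: scaled k-mer vectors of random ACGT sequences,
    usable as a KernelExplainer background instead of the cloned sample."""
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        seq = ''.join(rng.choice("ACGT") for _ in range(seq_len))
        rows.append(sequence_to_kmer_vector(seq, k=k))  # app.py helper, assumed in scope
    return scaler.transform(np.stack(rows, axis=0))     # shape (n, 4**k)

# Usage inside predict(), replacing the cloned-sample loop:
# background_data = random_background(scaler)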