Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -5,13 +5,12 @@ import numpy as np
|
|
5 |
from itertools import product
|
6 |
import torch.nn as nn
|
7 |
import matplotlib
|
8 |
-
matplotlib.use("Agg") #
|
9 |
import matplotlib.pyplot as plt
|
10 |
import io
|
11 |
from PIL import Image
|
12 |
import shap
|
13 |
|
14 |
-
|
15 |
###############################################################################
|
16 |
# Model Definition
|
17 |
###############################################################################
|
@@ -41,22 +40,17 @@ class VirusClassifier(nn.Module):
|
|
41 |
###############################################################################
|
42 |
class TorchModelWrapper:
|
43 |
"""
|
44 |
-
A simple callable that
|
45 |
-
|
46 |
-
to torch tensors, run the model, and return NumPy outputs.
|
47 |
"""
|
48 |
def __init__(self, model: nn.Module, device='cpu'):
|
49 |
self.model = model
|
50 |
self.device = device
|
51 |
|
52 |
def __call__(self, x_np: np.ndarray):
|
53 |
-
"""
|
54 |
-
x_np: shape=(batch_size, num_features) as a numpy array
|
55 |
-
Returns: numpy array of shape=(batch_size, num_outputs)
|
56 |
-
"""
|
57 |
x_torch = torch.from_numpy(x_np).float().to(self.device)
|
58 |
with torch.no_grad():
|
59 |
-
out = self.model(x_torch).cpu().numpy()
|
60 |
return out
|
61 |
|
62 |
|
@@ -65,8 +59,8 @@ class TorchModelWrapper:
|
|
65 |
###############################################################################
|
66 |
def parse_fasta(text):
|
67 |
"""
|
68 |
-
|
69 |
-
|
70 |
"""
|
71 |
sequences = []
|
72 |
current_header = None
|
@@ -89,8 +83,7 @@ def parse_fasta(text):
|
|
89 |
|
90 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
91 |
"""
|
92 |
-
Convert a single
|
93 |
-
of length 4^k (e.g., for k=4, length=256).
|
94 |
"""
|
95 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
96 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
@@ -103,424 +96,223 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
|
103 |
|
104 |
total_kmers = len(sequence) - k + 1
|
105 |
if total_kmers > 0:
|
106 |
-
vec
|
107 |
|
108 |
return vec
|
109 |
|
110 |
-
|
111 |
###############################################################################
|
112 |
-
# Visualization
|
113 |
###############################################################################
|
114 |
-
def
|
115 |
-
single_shap_values: np.ndarray,
|
116 |
-
raw_freq_vector: np.ndarray,
|
117 |
-
scaled_vector: np.ndarray,
|
118 |
-
kmer_list,
|
119 |
-
title: str
|
120 |
-
):
|
121 |
"""
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
raw_freq_vector: shape=(256,) original frequencies for this sample
|
127 |
-
scaled_vector: shape=(256,) scaled (Z-score) values for this sample
|
128 |
-
kmer_list: list of length=256 of all k-mers
|
129 |
"""
|
130 |
-
#
|
131 |
-
|
132 |
-
|
133 |
-
|
|
|
|
|
|
|
134 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
top_data = []
|
|
|
136 |
for idx in top_indices:
|
137 |
-
|
138 |
top_data.append({
|
139 |
-
"kmer": kmer_list[
|
140 |
-
"shap":
|
141 |
-
"abs_shap": abs_vals[
|
142 |
-
"
|
143 |
-
"sigma":
|
144 |
})
|
145 |
|
146 |
-
# Sort
|
147 |
top_data.sort(key=lambda x: x["abs_shap"], reverse=True)
|
148 |
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
# color by sign (positive=green => pushes "human", negative=red => pushes "non-human")
|
154 |
-
colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
|
155 |
|
156 |
x = np.arange(len(kmers))
|
157 |
width = 0.4
|
158 |
|
159 |
-
fig, ax = plt.subplots(figsize=(
|
160 |
-
|
161 |
-
ax.bar(
|
162 |
-
|
163 |
-
|
164 |
-
|
165 |
-
if len(freqs) > 0:
|
166 |
-
ax.set_ylim(0, max(freqs)*1.2)
|
167 |
|
168 |
-
# Twin axis for sigma
|
169 |
ax2 = ax.twinx()
|
170 |
-
ax2.bar(
|
171 |
-
|
172 |
-
)
|
173 |
-
ax2.set_ylabel("Standard Deviations (σ)", color='black')
|
174 |
|
175 |
ax.set_xticks(x)
|
176 |
ax.set_xticklabels(kmers, rotation=45, ha='right')
|
177 |
-
ax.set_title(
|
178 |
|
179 |
-
# Combine legends
|
180 |
lines1, labels1 = ax.get_legend_handles_labels()
|
181 |
lines2, labels2 = ax2.get_legend_handles_labels()
|
182 |
ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
|
183 |
|
184 |
plt.tight_layout()
|
185 |
-
|
|
|
|
|
|
|
|
|
|
|
186 |
|
187 |
|
188 |
###############################################################################
|
189 |
-
# Main
|
190 |
###############################################################################
|
191 |
-
def
|
192 |
"""
|
193 |
-
Reads
|
194 |
-
|
195 |
-
|
196 |
-
|
197 |
-
- array of scaled vectors
|
198 |
-
- list of k-mers
|
199 |
-
- error message or None
|
200 |
"""
|
201 |
-
# 1.
|
202 |
-
|
203 |
-
|
204 |
-
|
205 |
-
|
206 |
text = file_obj.decode("utf-8")
|
207 |
-
|
208 |
-
|
209 |
|
210 |
-
# 2. Parse
|
211 |
sequences = parse_fasta(text)
|
212 |
-
if
|
213 |
-
return
|
|
|
214 |
|
215 |
-
# 3. Convert
|
216 |
k = 4
|
217 |
-
|
218 |
-
headers = []
|
219 |
-
seqs = []
|
220 |
-
for (hdr, seq) in sequences:
|
221 |
-
raw_vec = sequence_to_kmer_vector(seq, k=k)
|
222 |
-
all_raw_vectors.append(raw_vec)
|
223 |
-
headers.append(hdr)
|
224 |
-
seqs.append(seq)
|
225 |
-
|
226 |
-
all_raw_vectors = np.stack(all_raw_vectors, axis=0) # shape=(num_seqs, 256)
|
227 |
|
228 |
# 4. Load model & scaler
|
|
|
|
|
229 |
try:
|
230 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
231 |
-
|
232 |
-
model = VirusClassifier(input_shape=4**k).to(device)
|
233 |
-
# Use weights_only=True to suppress future warnings about untrusted pickles
|
234 |
state_dict = torch.load("model.pt", map_location=device, weights_only=True)
|
235 |
model.load_state_dict(state_dict)
|
236 |
model.eval()
|
237 |
-
|
238 |
scaler = joblib.load("scaler.pkl")
|
239 |
except Exception as e:
|
240 |
-
return
|
241 |
|
242 |
# 5. Scale data
|
243 |
-
scaled_data = scaler.transform(
|
244 |
-
|
245 |
-
# 6. Predictions
|
246 |
X_tensor = torch.FloatTensor(scaled_data).to(device)
|
247 |
-
with torch.no_grad():
|
248 |
-
logits = model(X_tensor)
|
249 |
-
# shape=(num_seqs, 2)
|
250 |
-
probs = torch.softmax(logits, dim=1).cpu().numpy()
|
251 |
-
preds = np.argmax(probs, axis=1) # 0 or 1
|
252 |
-
|
253 |
-
results_table = []
|
254 |
-
for i, (hdr, seq) in enumerate(zip(headers, seqs)):
|
255 |
-
results_table.append({
|
256 |
-
"header": hdr,
|
257 |
-
"sequence": seq[:50] + ("..." if len(seq) > 50 else ""),
|
258 |
-
"pred_label": "human" if preds[i] == 1 else "non-human",
|
259 |
-
"human_prob": float(probs[i][1]),
|
260 |
-
"non_human_prob": float(probs[i][0]),
|
261 |
-
"confidence": float(np.max(probs[i]))
|
262 |
-
})
|
263 |
-
|
264 |
-
# 7. SHAP Explainer
|
265 |
-
# For large data, pick a smaller background subset
|
266 |
-
if scaled_data.shape[0] > 50:
|
267 |
-
background_data = scaled_data[:50]
|
268 |
-
else:
|
269 |
-
background_data = scaled_data
|
270 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
271 |
wrapped_model = TorchModelWrapper(model, device)
|
|
|
272 |
explainer = shap.Explainer(wrapped_model, background_data)
|
273 |
-
|
274 |
-
|
275 |
-
|
276 |
-
|
277 |
-
|
278 |
-
|
279 |
-
|
280 |
-
#
|
281 |
-
|
282 |
-
|
283 |
-
|
284 |
-
|
285 |
-
|
286 |
-
###############################################################################
|
287 |
-
def main_predict(file_obj):
|
288 |
-
"""
|
289 |
-
Triggered by the 'Run Classification' button in Gradio.
|
290 |
-
Returns a markdown table plus states for subsequent plots.
|
291 |
-
"""
|
292 |
-
results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
|
293 |
-
if err:
|
294 |
-
return (err, None, None, None, None)
|
295 |
-
|
296 |
-
if results is None or shap_vals is None:
|
297 |
-
return ("An unknown error occurred.", None, None, None, None)
|
298 |
-
|
299 |
-
# Build a summary for all sequences
|
300 |
-
md = "# Classification Results\n\n"
|
301 |
-
md += "| # | Header | Pred Label | Confidence | Human Prob | Non-human Prob |\n"
|
302 |
-
md += "|---|--------|------------|------------|------------|----------------|\n"
|
303 |
-
for i, row in enumerate(results):
|
304 |
-
md += (
|
305 |
-
f"| {i} | {row['header']} | {row['pred_label']} | "
|
306 |
-
f"{row['confidence']:.4f} | {row['human_prob']:.4f} | {row['non_human_prob']:.4f} |\n"
|
307 |
-
)
|
308 |
-
md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots (class=1/human)."
|
309 |
-
|
310 |
-
return (md, shap_vals, scaled_data, kmer_list, results)
|
311 |
-
|
312 |
-
|
313 |
-
def update_waterfall_plot(selected_index, shap_values_obj):
|
314 |
-
"""
|
315 |
-
Build a waterfall plot for the user-selected sample, but ONLY for class=1 (human).
|
316 |
-
shap_values_obj has shape=(num_samples, 2, num_features).
|
317 |
-
We do shap_values_obj[selected_index, 1] => shape=(num_features,)
|
318 |
-
for a single-sample single-class explanation.
|
319 |
-
"""
|
320 |
-
if shap_values_obj is None:
|
321 |
-
return None
|
322 |
-
|
323 |
-
import matplotlib.pyplot as plt
|
324 |
-
|
325 |
-
try:
|
326 |
-
selected_index = int(selected_index)
|
327 |
-
except:
|
328 |
-
selected_index = 0
|
329 |
-
|
330 |
-
# We only visualize class=1 ("human") SHAP values
|
331 |
-
# shap_values_obj.values shape => (num_samples, 2, num_features)
|
332 |
-
single_ex_values = shap_values_obj.values[selected_index, 1, :] # shape=(256,)
|
333 |
-
single_ex_base = shap_values_obj.base_values[selected_index, 1] # scalar
|
334 |
-
single_ex_data = shap_values_obj.data[selected_index] # shape=(256,)
|
335 |
-
|
336 |
-
# Construct a shap.Explanation object for just this one sample & class
|
337 |
-
single_expl = shap.Explanation(
|
338 |
-
values=single_ex_values,
|
339 |
-
base_values=single_ex_base,
|
340 |
-
data=single_ex_data,
|
341 |
-
feature_names=[f"feat_{i}" for i in range(single_ex_values.shape[0])]
|
342 |
)
|
343 |
|
344 |
-
|
345 |
-
|
346 |
-
|
347 |
-
|
348 |
-
|
349 |
-
|
350 |
-
|
351 |
-
|
352 |
-
return wf_img
|
353 |
-
|
354 |
-
|
355 |
-
def update_beeswarm_plot(shap_values_obj):
|
356 |
-
"""
|
357 |
-
Build a beeswarm plot across all samples, but only for class=1 (human).
|
358 |
-
We slice shap_values_obj to pick shap_values_obj.values[:, 1, :]
|
359 |
-
=> shape=(num_samples, num_features).
|
360 |
-
"""
|
361 |
-
if shap_values_obj is None:
|
362 |
-
return None
|
363 |
-
|
364 |
-
import matplotlib.pyplot as plt
|
365 |
-
|
366 |
-
# For multi-output, shap_values_obj.values shape => (num_samples, 2, num_features)
|
367 |
-
# We'll create a new Explanation object for class=1:
|
368 |
-
class1_vals = shap_values_obj.values[:, 1, :] # shape=(num_samples, num_features)
|
369 |
-
class1_base = shap_values_obj.base_values[:, 1] # shape=(num_samples,)
|
370 |
-
class1_data = shap_values_obj.data # shape=(num_samples, num_features)
|
371 |
-
|
372 |
-
# Some versions of shap store data in a 2D array, which is fine
|
373 |
-
# We'll re-wrap them in a shap.Explanation:
|
374 |
-
class1_expl = shap.Explanation(
|
375 |
-
values=class1_vals,
|
376 |
-
base_values=class1_base,
|
377 |
-
data=class1_data,
|
378 |
-
feature_names=[f"feat_{i}" for i in range(class1_vals.shape[1])]
|
379 |
)
|
380 |
|
381 |
-
|
382 |
-
|
383 |
-
|
384 |
-
plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
|
385 |
-
buf.seek(0)
|
386 |
-
bs_img = Image.open(buf)
|
387 |
-
plt.close(beeswarm_fig)
|
388 |
-
|
389 |
-
return bs_img
|
390 |
-
|
391 |
-
|
392 |
-
def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
|
393 |
-
"""
|
394 |
-
Create the frequency & σ bar chart for the selected sequence's top-10 k-mers (by abs SHAP).
|
395 |
-
Again, we'll use class=1 SHAP values only.
|
396 |
-
"""
|
397 |
-
if shap_values_obj is None or scaled_data is None or kmer_list is None:
|
398 |
-
return None
|
399 |
|
400 |
-
|
401 |
|
402 |
-
|
403 |
-
|
404 |
-
except:
|
405 |
-
selected_index = 0
|
406 |
|
407 |
-
|
408 |
-
|
409 |
-
text = file_obj
|
410 |
-
else:
|
411 |
-
text = file_obj.decode('utf-8')
|
412 |
|
413 |
-
|
414 |
-
|
415 |
-
if selected_index >= len(sequences):
|
416 |
-
selected_index = 0
|
417 |
-
|
418 |
-
seq = sequences[selected_index][1]
|
419 |
-
raw_vec = sequence_to_kmer_vector(seq, k=4) # shape=(256,)
|
420 |
-
|
421 |
-
# SHAP for class=1 => shape=(num_samples, 2, 256)
|
422 |
-
single_shap_values = shap_values_obj.values[selected_index, 1, :]
|
423 |
-
freq_sigma_fig = create_freq_sigma_plot(
|
424 |
-
single_shap_values,
|
425 |
-
raw_freq_vector=raw_vec,
|
426 |
-
scaled_vector=scaled_data[selected_index],
|
427 |
-
kmer_list=kmer_list,
|
428 |
-
title=f"Sample #{selected_index} — {sequences[selected_index][0]}"
|
429 |
-
)
|
430 |
-
|
431 |
-
buf = io.BytesIO()
|
432 |
-
freq_sigma_fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
|
433 |
-
buf.seek(0)
|
434 |
-
fs_img = Image.open(buf)
|
435 |
-
plt.close(freq_sigma_fig)
|
436 |
|
437 |
-
return
|
438 |
|
439 |
|
440 |
###############################################################################
|
441 |
# Gradio Interface
|
442 |
###############################################################################
|
443 |
-
with gr.Blocks(title="
|
444 |
-
|
445 |
-
|
446 |
-
gr.Markdown(
|
447 |
-
"""
|
448 |
-
# **irus Host Classifier**
|
449 |
-
Upload a FASTA file with one or more nucleotide sequences.
|
450 |
-
This app will:
|
451 |
-
1. Predict each sequence's **host** (human vs. non-human).
|
452 |
-
2. Provide **SHAP** explanations focusing on the 'human' class (index=1).
|
453 |
-
3. Display:
|
454 |
-
- A **waterfall** plot per-sequence (top features).
|
455 |
-
- A **beeswarm** plot across all sequences (global summary).
|
456 |
-
- A **frequency & σ** bar chart for the top-10 k-mers of any selected sequence.
|
457 |
-
"""
|
458 |
-
)
|
459 |
|
460 |
-
|
461 |
-
|
462 |
-
run_btn = gr.Button("Run Classification")
|
463 |
-
|
464 |
-
# Store intermediate results in Gradio states
|
465 |
-
shap_values_state = gr.State()
|
466 |
-
scaled_data_state = gr.State()
|
467 |
-
kmer_list_state = gr.State()
|
468 |
-
results_state = gr.State()
|
469 |
-
file_data_state = gr.State()
|
470 |
|
471 |
with gr.Tabs():
|
472 |
-
with gr.Tab("Results
|
473 |
md_out = gr.Markdown()
|
474 |
-
|
475 |
with gr.Tab("SHAP Waterfall"):
|
476 |
-
|
477 |
-
|
478 |
-
|
479 |
-
|
480 |
-
wf_plot = gr.Image(label="SHAP Waterfall Plot")
|
481 |
-
|
482 |
-
with gr.Tab("SHAP Beeswarm"):
|
483 |
-
bs_plot = gr.Image(label="Global Beeswarm Plot", height=500)
|
484 |
-
|
485 |
-
with gr.Tab("Top-10 Frequency & Sigma"):
|
486 |
-
with gr.Row():
|
487 |
-
seq_index_input2 = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
|
488 |
-
update_fs_btn = gr.Button("Update Frequency Chart")
|
489 |
-
fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")
|
490 |
|
491 |
-
# 1) Main classification
|
492 |
run_btn.click(
|
493 |
-
fn=
|
494 |
inputs=[file_input],
|
495 |
-
outputs=[md_out,
|
496 |
-
)
|
497 |
-
run_btn.click(
|
498 |
-
fn=lambda x: x,
|
499 |
-
inputs=file_input,
|
500 |
-
outputs=file_data_state
|
501 |
-
)
|
502 |
-
|
503 |
-
# 2) Update Waterfall
|
504 |
-
update_wf_btn.click(
|
505 |
-
fn=update_waterfall_plot,
|
506 |
-
inputs=[seq_index_input, shap_values_state],
|
507 |
-
outputs=[wf_plot]
|
508 |
-
)
|
509 |
-
|
510 |
-
# 3) Update Beeswarm right after classification
|
511 |
-
run_btn.click(
|
512 |
-
fn=update_beeswarm_plot,
|
513 |
-
inputs=[shap_values_state],
|
514 |
-
outputs=[bs_plot]
|
515 |
-
)
|
516 |
-
|
517 |
-
# 4) Update Frequency & σ
|
518 |
-
update_fs_btn.click(
|
519 |
-
fn=update_freq_plot,
|
520 |
-
inputs=[seq_index_input2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
|
521 |
-
outputs=[fs_plot]
|
522 |
)
|
523 |
|
|
|
524 |
if __name__ == "__main__":
|
525 |
-
demo.launch(
|
526 |
-
|
|
|
5 |
from itertools import product
|
6 |
import torch.nn as nn
|
7 |
import matplotlib
|
8 |
+
matplotlib.use("Agg") # If no display environment
|
9 |
import matplotlib.pyplot as plt
|
10 |
import io
|
11 |
from PIL import Image
|
12 |
import shap
|
13 |
|
|
|
14 |
###############################################################################
|
15 |
# Model Definition
|
16 |
###############################################################################
|
|
|
40 |
###############################################################################
|
41 |
class TorchModelWrapper:
    """
    Adapter that lets a torch module be called with NumPy data.

    shap drives the model through a plain callable that takes and returns
    NumPy arrays; this wrapper performs the array -> tensor -> array
    round trip around a gradient-free forward pass.
    """
    def __init__(self, model: nn.Module, device='cpu'):
        # The module is used as given: the wrapper never switches it to
        # eval mode or moves it to another device.
        self.model = model
        self.device = device

    def __call__(self, x_np: np.ndarray):
        """
        Run a forward pass on ``x_np`` and return the output as a NumPy array.

        ``x_np`` may be any array-like. It is copied into a C-contiguous
        float32 array first because ``torch.from_numpy`` rejects
        negatively-strided views (e.g. ``a[::-1]``) and anything that is
        not an ndarray.
        """
        x_contig = np.ascontiguousarray(x_np, dtype=np.float32)
        x_torch = torch.from_numpy(x_contig).to(self.device)
        with torch.no_grad():
            # Output shape is whatever the wrapped model produces
            # (presumably (batch, 2) for this classifier - confirm).
            out = self.model(x_torch).cpu().numpy()
        return out
|
55 |
|
56 |
|
|
|
59 |
###############################################################################
|
60 |
def parse_fasta(text):
|
61 |
"""
|
62 |
+
Parse FASTA text, returning a list of (header, sequence).
|
63 |
+
We'll only use the *first* sequence in practice.
|
64 |
"""
|
65 |
sequences = []
|
66 |
current_header = None
|
|
|
83 |
|
84 |
def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
|
85 |
"""
|
86 |
+
Convert a single sequence to a 4^k dimension k-mer frequency vector.
|
|
|
87 |
"""
|
88 |
kmers = [''.join(p) for p in product("ACGT", repeat=k)]
|
89 |
kmer_dict = {km: i for i, km in enumerate(kmers)}
|
|
|
96 |
|
97 |
total_kmers = len(sequence) - k + 1
|
98 |
if total_kmers > 0:
|
99 |
+
vec /= total_kmers
|
100 |
|
101 |
return vec
|
102 |
|
|
|
103 |
###############################################################################
|
104 |
+
# Visualization
|
105 |
###############################################################################
|
106 |
+
def create_waterfall_plot(shap_values, base_value, data, max_display=15):
    """
    Render a SHAP waterfall plot for one sample and one class as a PIL image.

    Args:
        shap_values: per-feature SHAP values for the sample (1-D sequence).
        base_value: the explainer's base value for the same class.
        data: the input values the explanation refers to, aligned with
            ``shap_values``.
        max_display: maximum number of features forwarded to
            ``shap.plots.waterfall``.

    Returns:
        A PIL image decoded from the PNG rendering of the plot.
    """
    # Features carry no real names here, so generic "feat_<i>" labels are used.
    expl = shap.Explanation(
        values=shap_values,
        base_values=base_value,
        data=data,
        feature_names=[f"feat_{i}" for i in range(len(shap_values))]
    )

    fig = plt.figure(figsize=(6, 4), dpi=75)
    buf = io.BytesIO()
    try:
        # show=False keeps shap from calling plt.show(); the plot is drawn
        # on the current pyplot figure, which plt.savefig then writes out.
        shap.plots.waterfall(expl, max_display=max_display, show=False)
        plt.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Always release the figure, even if shap or savefig raises, so a
        # long-running server does not accumulate open figures.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
|
129 |
+
|
130 |
+
def create_freq_sigma_plot(shap_values, raw_freq, scaled_vec, kmer_list, title="Top-10 k-mers"):
    """
    Plot frequency (%) and z-score for the ten k-mers with the largest |SHAP|.

    Args:
        shap_values: per-feature SHAP values for one sample.
        raw_freq: unscaled k-mer frequencies, indexed like ``shap_values``.
        scaled_vec: scaled values (z-scores), indexed like ``shap_values``.
        kmer_list: k-mer label for each feature index.
        title: title drawn above the chart.

    Returns:
        A PIL image decoded from the PNG rendering of the chart.
    """
    abs_vals = np.abs(shap_values)
    # argsort is ascending, so take the last ten and reverse them to get
    # indices ordered by descending |SHAP|. No second sort is needed.
    top_indices = [int(i) for i in np.argsort(abs_vals)[-10:][::-1]]

    kmers = [kmer_list[i] for i in top_indices]
    freqs = [raw_freq[i] * 100.0 for i in top_indices]  # fraction -> percent
    sigmas = [scaled_vec[i] for i in top_indices]
    # Green marks a non-negative SHAP value, red a negative one.
    colors = ["green" if shap_values[i] >= 0 else "red" for i in top_indices]

    x = np.arange(len(kmers))
    width = 0.4

    fig, ax = plt.subplots(figsize=(6, 4), dpi=75)
    buf = io.BytesIO()
    try:
        ax.bar(x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
        ax.set_ylabel("Frequency (%)")
        # Leave headroom above the tallest bar. Skip when every frequency
        # is zero, because identical axis limits make matplotlib warn.
        if freqs and max(freqs) > 0:
            ax.set_ylim(0, max(freqs)*1.25)

        # Second y-axis so z-scores keep their own scale.
        ax2 = ax.twinx()
        ax2.bar(x + width/2, sigmas, width, color="gray", alpha=0.5, label="Z-score")
        ax2.set_ylabel("Standard Deviations (σ)")

        ax.set_xticks(x)
        ax.set_xticklabels(kmers, rotation=45, ha='right')
        ax.set_title(title)

        # Merge the handles of both axes into a single legend.
        lines1, labels1 = ax.get_legend_handles_labels()
        lines2, labels2 = ax2.get_legend_handles_labels()
        ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')

        fig.tight_layout()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        # Release the figure even when plotting fails.
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
|
189 |
|
190 |
|
191 |
###############################################################################
|
192 |
+
# Main Gradio Logic
|
193 |
###############################################################################
|
194 |
+
def classify_and_explain(file_obj):
    """
    Classify the first sequence of an uploaded FASTA and explain the result.

    Steps: decode the upload, parse it, turn the first sequence into a k-mer
    frequency vector, load the model and scaler from disk, predict, compute
    SHAP values for class index 1, and render two plots.

    Args:
        file_obj: the uploaded content, either ``str`` or an object with a
            ``decode`` method (e.g. ``bytes``).

    Returns:
        A 3-tuple ``(markdown, waterfall_image, freq_sigma_image)``. On
        failure the first element is an error message and the two images
        are ``None``.
    """
    # 1. Read text from user input
    if isinstance(file_obj, str):
        text = file_obj
    else:
        try:
            text = file_obj.decode("utf-8")
        except (AttributeError, UnicodeDecodeError):
            # AttributeError: nothing uploaded (None) or an unexpected type.
            # UnicodeDecodeError: the payload is not valid UTF-8.
            return "Error reading file!", None, None

    # 2. Parse
    sequences = parse_fasta(text)
    if not sequences:
        return "No valid FASTA sequences found!", None, None
    header, seq = sequences[0]  # only the first sequence is used

    # 3. Convert to k-mer frequencies
    k = 4
    raw_freq_vec = sequence_to_kmer_vector(seq, k=k)

    # 4. Load model & scaler
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = VirusClassifier(4**k).to(device)
    try:
        state_dict = torch.load("model.pt", map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        # NOTE(review): joblib.load unpickles the file; this is only safe
        # if scaler.pkl is a trusted artifact shipped with the app - confirm.
        scaler = joblib.load("scaler.pkl")
    except Exception as e:
        return f"Error loading model/scaler: {str(e)}", None, None

    # 5. Scale data (the scaler is given a single-row 2-D array)
    scaled_data = scaler.transform(raw_freq_vec.reshape(1, -1))
    X_tensor = torch.FloatTensor(scaled_data).to(device)

    # 6. Predict
    with torch.no_grad():
        out = model(X_tensor)
    probs = torch.softmax(out, dim=1).cpu().numpy()[0]
    pred_class = np.argmax(probs)
    pred_label = "human" if pred_class == 1 else "non-human"
    confidence = float(np.max(probs))
    human_prob = float(probs[1])
    nonhuman_prob = float(probs[0])

    # 7. SHAP explanation (single sample)
    wrapped_model = TorchModelWrapper(model, device)
    # The sample doubles as its own background set: simple, but it gives
    # the explainer only one reference point.
    background_data = scaled_data
    explainer = shap.Explainer(wrapped_model, background_data)
    shap_values = explainer(scaled_data)

    # Keep only the class-1 slice of the explanation for sample 0.
    sample_shap_vals = shap_values.values[0, 1, :]
    base_value = shap_values.base_values[0, 1]
    sample_data = shap_values.data[0]

    # 8. Create the two plots
    wf_img = create_waterfall_plot(
        shap_values=sample_shap_vals,
        base_value=base_value,
        data=sample_data,
        max_display=15
    )

    kmers = [''.join(p) for p in product("ACGT", repeat=k)]
    # Append an ellipsis only when the header really was cut short.
    short_header = header if len(header) <= 25 else header[:25] + "..."
    freq_img = create_freq_sigma_plot(
        shap_values=sample_shap_vals,
        raw_freq=raw_freq_vec,
        scaled_vec=scaled_data[0],
        kmer_list=kmers,
        title=f"{short_header} (Top-10 K-mers)"
    )

    # 9. Markdown result, assembled line by line so no source indentation
    # leaks into the text (indented Markdown renders as a code block).
    result_md = "\n".join([
        "",
        "# Classification Result",
        "",
        f"**Header**: {header}",
        "",
        f"**Predicted Label**: {pred_label}",
        f"**Confidence**: {confidence:.4f}",
        "",
        f"**Human Probability**: {human_prob:.4f}",
        f"**Non-human Probability**: {nonhuman_prob:.4f}",
        "",
        "Above are the SHAP-based analyses (class=1).",
        "",
    ])

    return result_md, wf_img, freq_img
|
291 |
|
292 |
|
293 |
###############################################################################
|
294 |
# Gradio Interface
|
295 |
###############################################################################
|
296 |
+
with gr.Blocks(title="Single-Sequence Virus Host Classifier") as demo:
    # Heading shown at the top of the page.
    gr.Markdown(
        "## Upload a FASTA file containing **one** (or more) sequences. We only use the **first**."
    )

    # Upload widget (delivers raw bytes) and the button that starts the run.
    file_input = gr.File(label="Upload FASTA", type="binary")
    run_btn = gr.Button("Classify & Explain")

    # One tab per output of classify_and_explain.
    with gr.Tabs():
        with gr.Tab("Results"):
            md_out = gr.Markdown()
        with gr.Tab("SHAP Waterfall"):
            wf_out = gr.Image(label="Waterfall Plot (class=1)")
        with gr.Tab("Top-10 K-mer Plot"):
            freq_out = gr.Image(label="Frequency & Sigma (class=1)")

    # The three return values map onto the three output components in order.
    run_btn.click(
        classify_and_explain,
        [file_input],
        [md_out, wf_out, freq_out],
    )

# No share=True -> avoid HF Spaces warning
if __name__ == "__main__":
    demo.launch()
|
|