hiyata committed on
Commit 7e19501 · verified · 1 Parent(s): afbf1c6

Update app.py

Files changed (1)
  1. app.py +147 -355
app.py CHANGED
@@ -5,13 +5,12 @@ import numpy as np
5
  from itertools import product
6
  import torch.nn as nn
7
  import matplotlib
8
- matplotlib.use("Agg") # In case we're running in a no-display environment
9
  import matplotlib.pyplot as plt
10
  import io
11
  from PIL import Image
12
  import shap
13
 
14
-
15
  ###############################################################################
16
  # Model Definition
17
  ###############################################################################
@@ -41,22 +40,17 @@ class VirusClassifier(nn.Module):
41
  ###############################################################################
42
  class TorchModelWrapper:
43
  """
44
- A simple callable that takes a PyTorch model and device,
45
- allowing SHAP to pass in NumPy arrays. We convert them
46
- to torch tensors, run the model, and return NumPy outputs.
47
  """
48
  def __init__(self, model: nn.Module, device='cpu'):
49
  self.model = model
50
  self.device = device
51
 
52
  def __call__(self, x_np: np.ndarray):
53
- """
54
- x_np: shape=(batch_size, num_features) as a numpy array
55
- Returns: numpy array of shape=(batch_size, num_outputs)
56
- """
57
  x_torch = torch.from_numpy(x_np).float().to(self.device)
58
  with torch.no_grad():
59
- out = self.model(x_torch).cpu().numpy()
60
  return out
61
 
62
 
@@ -65,8 +59,8 @@ class TorchModelWrapper:
65
  ###############################################################################
66
  def parse_fasta(text):
67
  """
68
- Parses text input in FASTA format into a list of (header, sequence).
69
- Handles multiple sequences if present.
70
  """
71
  sequences = []
72
  current_header = None
@@ -89,8 +83,7 @@ def parse_fasta(text):
89
 
90
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
91
  """
92
- Convert a single nucleotide sequence to a k-mer frequency vector
93
- of length 4^k (e.g., for k=4, length=256).
94
  """
95
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
96
  kmer_dict = {km: i for i, km in enumerate(kmers)}
@@ -103,424 +96,223 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
103
 
104
  total_kmers = len(sequence) - k + 1
105
  if total_kmers > 0:
106
- vec = vec / total_kmers # normalize frequencies
107
 
108
  return vec
109
 
110
-
111
  ###############################################################################
112
- # Visualization Helpers
113
  ###############################################################################
114
- def create_freq_sigma_plot(
115
- single_shap_values: np.ndarray,
116
- raw_freq_vector: np.ndarray,
117
- scaled_vector: np.ndarray,
118
- kmer_list,
119
- title: str
120
- ):
121
  """
122
- Creates a bar plot showing top-10 k-mers (by absolute SHAP value),
123
- with frequency (%) and sigma from mean on a twin-axis.
124
-
125
- single_shap_values: shape=(256,) SHAP values for the "human" class
126
- raw_freq_vector: shape=(256,) original frequencies for this sample
127
- scaled_vector: shape=(256,) scaled (Z-score) values for this sample
128
- kmer_list: list of length=256 of all k-mers
129
  """
130
- # Identify the top 10 k-mers by absolute shap
131
- abs_vals = np.abs(single_shap_values) # shape=(256,)
132
- top_k = 10
133
- top_indices = np.argsort(abs_vals)[-top_k:][::-1] # indices of largest -> smallest
135
  top_data = []
 
136
  for idx in top_indices:
137
- idx_int = int(idx) # ensure integer
138
  top_data.append({
139
- "kmer": kmer_list[idx_int],
140
- "shap": single_shap_values[idx_int],
141
- "abs_shap": abs_vals[idx_int],
142
- "frequency": raw_freq_vector[idx_int] * 100.0, # percentage
143
- "sigma": scaled_vector[idx_int]
144
  })
145
 
146
- # Sort top_data by abs_shap descending
147
  top_data.sort(key=lambda x: x["abs_shap"], reverse=True)
148
 
149
- # Prepare for plotting
150
- kmers = [d["kmer"] for d in top_data]
151
- freqs = [d["frequency"] for d in top_data]
152
- sigmas = [d["sigma"] for d in top_data]
153
- # color by sign (positive=green => pushes "human", negative=red => pushes "non-human")
154
- colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
155
 
156
  x = np.arange(len(kmers))
157
  width = 0.4
158
 
159
- fig, ax = plt.subplots(figsize=(8, 5))
160
- # Frequency
161
- ax.bar(
162
- x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)"
163
- )
164
- ax.set_ylabel("Frequency (%)", color='black')
165
- if len(freqs) > 0:
166
- ax.set_ylim(0, max(freqs)*1.2)
167
 
168
- # Twin axis for sigma
169
  ax2 = ax.twinx()
170
- ax2.bar(
171
- x + width/2, sigmas, width, color="gray", alpha=0.5, label="σ from Mean"
172
- )
173
- ax2.set_ylabel("Standard Deviations (σ)", color='black')
174
 
175
  ax.set_xticks(x)
176
  ax.set_xticklabels(kmers, rotation=45, ha='right')
177
- ax.set_title(f"Top-10 K-mers (Frequency & σ)\n{title}")
178
 
179
- # Combine legends
180
  lines1, labels1 = ax.get_legend_handles_labels()
181
  lines2, labels2 = ax2.get_legend_handles_labels()
182
  ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
183
 
184
  plt.tight_layout()
185
- return fig
186
 
187
 
188
  ###############################################################################
189
- # Main Inference & SHAP Logic
190
  ###############################################################################
191
- def run_classification_and_shap(file_obj):
192
  """
193
- Reads one or more FASTA sequences from file_obj or text.
194
- Returns:
195
- - Table of results (list of dicts) for each sequence
196
- - shap_values object (SHAP values for the entire batch, shape=(num_samples, 2, num_features))
197
- - array of scaled vectors
198
- - list of k-mers
199
- - error message or None
200
  """
201
- # 1. Basic read
202
- if isinstance(file_obj, str):
203
- text = file_obj
204
- else:
205
- try:
206
  text = file_obj.decode("utf-8")
207
- except Exception as e:
208
- return None, None, None, None, f"Error reading file: {str(e)}"
209
 
210
- # 2. Parse FASTA
211
  sequences = parse_fasta(text)
212
- if len(sequences) == 0:
213
- return None, None, None, None, "No valid FASTA sequences found!"
 
214
 
215
- # 3. Convert each sequence to k-mer vector
216
  k = 4
217
- all_raw_vectors = []
218
- headers = []
219
- seqs = []
220
- for (hdr, seq) in sequences:
221
- raw_vec = sequence_to_kmer_vector(seq, k=k)
222
- all_raw_vectors.append(raw_vec)
223
- headers.append(hdr)
224
- seqs.append(seq)
225
-
226
- all_raw_vectors = np.stack(all_raw_vectors, axis=0) # shape=(num_seqs, 256)
227
 
228
  # 4. Load model & scaler
 
 
229
  try:
230
- device = "cuda" if torch.cuda.is_available() else "cpu"
231
-
232
- model = VirusClassifier(input_shape=4**k).to(device)
233
- # Use weights_only=True to suppress future warnings about untrusted pickles
234
  state_dict = torch.load("model.pt", map_location=device, weights_only=True)
235
  model.load_state_dict(state_dict)
236
  model.eval()
237
-
238
  scaler = joblib.load("scaler.pkl")
239
  except Exception as e:
240
- return None, None, None, None, f"Error loading model or scaler: {str(e)}"
241
 
242
  # 5. Scale data
243
- scaled_data = scaler.transform(all_raw_vectors) # shape=(num_seqs, 256)
244
-
245
- # 6. Predictions
246
  X_tensor = torch.FloatTensor(scaled_data).to(device)
247
- with torch.no_grad():
248
- logits = model(X_tensor)
249
- # shape=(num_seqs, 2)
250
- probs = torch.softmax(logits, dim=1).cpu().numpy()
251
- preds = np.argmax(probs, axis=1) # 0 or 1
252
-
253
- results_table = []
254
- for i, (hdr, seq) in enumerate(zip(headers, seqs)):
255
- results_table.append({
256
- "header": hdr,
257
- "sequence": seq[:50] + ("..." if len(seq) > 50 else ""),
258
- "pred_label": "human" if preds[i] == 1 else "non-human",
259
- "human_prob": float(probs[i][1]),
260
- "non_human_prob": float(probs[i][0]),
261
- "confidence": float(np.max(probs[i]))
262
- })
263
-
264
- # 7. SHAP Explainer
265
- # For large data, pick a smaller background subset
266
- if scaled_data.shape[0] > 50:
267
- background_data = scaled_data[:50]
268
- else:
269
- background_data = scaled_data
270
 
271
  wrapped_model = TorchModelWrapper(model, device)
 
272
  explainer = shap.Explainer(wrapped_model, background_data)
273
- # shap_values shape=(num_samples, num_features) if single-output
274
- # but here we have 2 outputs => shape=(num_samples, 2, num_features).
275
- shap_values = explainer(scaled_data)
276
-
277
- # Prepare k-mer list
278
- kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
279
-
280
- # Return everything
281
- return (results_table, shap_values, scaled_data, kmer_list, None)
282
-
283
-
284
- ###############################################################################
285
- # Gradio Callback Functions
286
- ###############################################################################
287
- def main_predict(file_obj):
288
- """
289
- Triggered by the 'Run Classification' button in Gradio.
290
- Returns a markdown table plus states for subsequent plots.
291
- """
292
- results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
293
- if err:
294
- return (err, None, None, None, None)
295
-
296
- if results is None or shap_vals is None:
297
- return ("An unknown error occurred.", None, None, None, None)
298
-
299
- # Build a summary for all sequences
300
- md = "# Classification Results\n\n"
301
- md += "| # | Header | Pred Label | Confidence | Human Prob | Non-human Prob |\n"
302
- md += "|---|--------|------------|------------|------------|----------------|\n"
303
- for i, row in enumerate(results):
304
- md += (
305
- f"| {i} | {row['header']} | {row['pred_label']} | "
306
- f"{row['confidence']:.4f} | {row['human_prob']:.4f} | {row['non_human_prob']:.4f} |\n"
307
- )
308
- md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots (class=1/human)."
309
-
310
- return (md, shap_vals, scaled_data, kmer_list, results)
311
-
312
-
313
- def update_waterfall_plot(selected_index, shap_values_obj):
314
- """
315
- Build a waterfall plot for the user-selected sample, but ONLY for class=1 (human).
316
- shap_values_obj has shape=(num_samples, 2, num_features).
317
- We do shap_values_obj[selected_index, 1] => shape=(num_features,)
318
- for a single-sample single-class explanation.
319
- """
320
- if shap_values_obj is None:
321
- return None
322
-
323
- import matplotlib.pyplot as plt
324
-
325
- try:
326
- selected_index = int(selected_index)
327
- except:
328
- selected_index = 0
329
-
330
- # We only visualize class=1 ("human") SHAP values
331
- # shap_values_obj.values shape => (num_samples, 2, num_features)
332
- single_ex_values = shap_values_obj.values[selected_index, 1, :] # shape=(256,)
333
- single_ex_base = shap_values_obj.base_values[selected_index, 1] # scalar
334
- single_ex_data = shap_values_obj.data[selected_index] # shape=(256,)
335
-
336
- # Construct a shap.Explanation object for just this one sample & class
337
- single_expl = shap.Explanation(
338
- values=single_ex_values,
339
- base_values=single_ex_base,
340
- data=single_ex_data,
341
- feature_names=[f"feat_{i}" for i in range(single_ex_values.shape[0])]
342
  )
343
 
344
- shap_plots_fig = plt.figure(figsize=(8, 5))
345
- shap.plots.waterfall(single_expl, max_display=14, show=False)
346
- buf = io.BytesIO()
347
- plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
348
- buf.seek(0)
349
- wf_img = Image.open(buf)
350
- plt.close(shap_plots_fig)
351
-
352
- return wf_img
353
-
354
-
355
- def update_beeswarm_plot(shap_values_obj):
356
- """
357
- Build a beeswarm plot across all samples, but only for class=1 (human).
358
- We slice shap_values_obj to pick shap_values_obj.values[:, 1, :]
359
- => shape=(num_samples, num_features).
360
- """
361
- if shap_values_obj is None:
362
- return None
363
-
364
- import matplotlib.pyplot as plt
365
-
366
- # For multi-output, shap_values_obj.values shape => (num_samples, 2, num_features)
367
- # We'll create a new Explanation object for class=1:
368
- class1_vals = shap_values_obj.values[:, 1, :] # shape=(num_samples, num_features)
369
- class1_base = shap_values_obj.base_values[:, 1] # shape=(num_samples,)
370
- class1_data = shap_values_obj.data # shape=(num_samples, num_features)
371
-
372
- # Some versions of shap store data in a 2D array, which is fine
373
- # We'll re-wrap them in a shap.Explanation:
374
- class1_expl = shap.Explanation(
375
- values=class1_vals,
376
- base_values=class1_base,
377
- data=class1_data,
378
- feature_names=[f"feat_{i}" for i in range(class1_vals.shape[1])]
379
  )
380
 
381
- beeswarm_fig = plt.figure(figsize=(8, 5))
382
- shap.plots.beeswarm(class1_expl, show=False)
383
- buf = io.BytesIO()
384
- plt.savefig(buf, format='png', bbox_inches='tight', dpi=120)
385
- buf.seek(0)
386
- bs_img = Image.open(buf)
387
- plt.close(beeswarm_fig)
388
-
389
- return bs_img
390
-
391
-
392
- def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
393
- """
394
- Create the frequency & σ bar chart for the selected sequence's top-10 k-mers (by abs SHAP).
395
- Again, we'll use class=1 SHAP values only.
396
- """
397
- if shap_values_obj is None or scaled_data is None or kmer_list is None:
398
- return None
399
 
400
- import matplotlib.pyplot as plt
401
 
402
- try:
403
- selected_index = int(selected_index)
404
- except:
405
- selected_index = 0
406
 
407
- # Re-parse the FASTA to get the corresponding sequence
408
- if isinstance(file_obj, str):
409
- text = file_obj
410
- else:
411
- text = file_obj.decode('utf-8')
412
 
413
- sequences = parse_fasta(text)
414
- # If out of range, clamp to 0
415
- if selected_index >= len(sequences):
416
- selected_index = 0
417
-
418
- seq = sequences[selected_index][1]
419
- raw_vec = sequence_to_kmer_vector(seq, k=4) # shape=(256,)
420
-
421
- # SHAP for class=1 => shape=(num_samples, 2, 256)
422
- single_shap_values = shap_values_obj.values[selected_index, 1, :]
423
- freq_sigma_fig = create_freq_sigma_plot(
424
- single_shap_values,
425
- raw_freq_vector=raw_vec,
426
- scaled_vector=scaled_data[selected_index],
427
- kmer_list=kmer_list,
428
- title=f"Sample #{selected_index} — {sequences[selected_index][0]}"
429
- )
430
-
431
- buf = io.BytesIO()
432
- freq_sigma_fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
433
- buf.seek(0)
434
- fs_img = Image.open(buf)
435
- plt.close(freq_sigma_fig)
436
 
437
- return fs_img
438
 
439
 
440
  ###############################################################################
441
  # Gradio Interface
442
  ###############################################################################
443
- with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
444
- shap.initjs() # load shap JS if needed for HTML-based plots (optional)
445
-
446
- gr.Markdown(
447
- """
448
- # **Virus Host Classifier**
449
- Upload a FASTA file with one or more nucleotide sequences.
450
- This app will:
451
- 1. Predict each sequence's **host** (human vs. non-human).
452
- 2. Provide **SHAP** explanations focusing on the 'human' class (index=1).
453
- 3. Display:
454
- - A **waterfall** plot per-sequence (top features).
455
- - A **beeswarm** plot across all sequences (global summary).
456
- - A **frequency & σ** bar chart for the top-10 k-mers of any selected sequence.
457
- """
458
- )
459
 
460
- with gr.Row():
461
- file_input = gr.File(label="Upload FASTA", type="binary")
462
- run_btn = gr.Button("Run Classification")
463
-
464
- # Store intermediate results in Gradio states
465
- shap_values_state = gr.State()
466
- scaled_data_state = gr.State()
467
- kmer_list_state = gr.State()
468
- results_state = gr.State()
469
- file_data_state = gr.State()
470
 
471
  with gr.Tabs():
472
- with gr.Tab("Results Table"):
473
  md_out = gr.Markdown()
474
-
475
  with gr.Tab("SHAP Waterfall"):
476
- with gr.Row():
477
- seq_index_input = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
478
- update_wf_btn = gr.Button("Update Waterfall")
479
-
480
- wf_plot = gr.Image(label="SHAP Waterfall Plot")
481
-
482
- with gr.Tab("SHAP Beeswarm"):
483
- bs_plot = gr.Image(label="Global Beeswarm Plot", height=500)
484
-
485
- with gr.Tab("Top-10 Frequency & Sigma"):
486
- with gr.Row():
487
- seq_index_input2 = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
488
- update_fs_btn = gr.Button("Update Frequency Chart")
489
- fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")
490
 
491
- # 1) Main classification
492
  run_btn.click(
493
- fn=main_predict,
494
  inputs=[file_input],
495
- outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state]
496
- )
497
- run_btn.click(
498
- fn=lambda x: x,
499
- inputs=file_input,
500
- outputs=file_data_state
501
- )
502
-
503
- # 2) Update Waterfall
504
- update_wf_btn.click(
505
- fn=update_waterfall_plot,
506
- inputs=[seq_index_input, shap_values_state],
507
- outputs=[wf_plot]
508
- )
509
-
510
- # 3) Update Beeswarm right after classification
511
- run_btn.click(
512
- fn=update_beeswarm_plot,
513
- inputs=[shap_values_state],
514
- outputs=[bs_plot]
515
- )
516
-
517
- # 4) Update Frequency & σ
518
- update_fs_btn.click(
519
- fn=update_freq_plot,
520
- inputs=[seq_index_input2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
521
- outputs=[fs_plot]
522
  )
523
 
 
524
  if __name__ == "__main__":
525
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
526
-
 
5
  from itertools import product
6
  import torch.nn as nn
7
  import matplotlib
8
+ matplotlib.use("Agg") # If no display environment
9
  import matplotlib.pyplot as plt
10
  import io
11
  from PIL import Image
12
  import shap
13
 
 
14
  ###############################################################################
15
  # Model Definition
16
  ###############################################################################
 
40
  ###############################################################################
41
  class TorchModelWrapper:
42
  """
43
+ A simple callable that converts incoming NumPy arrays to torch Tensors,
44
+ does a forward pass, and returns NumPy arrays. This is needed for shap.
 
45
  """
46
  def __init__(self, model: nn.Module, device='cpu'):
47
  self.model = model
48
  self.device = device
49
 
50
  def __call__(self, x_np: np.ndarray):
51
  x_torch = torch.from_numpy(x_np).float().to(self.device)
52
  with torch.no_grad():
53
+ out = self.model(x_torch).cpu().numpy() # shape=(batch,2)
54
  return out
55
 
56
 
 
59
  ###############################################################################
60
  def parse_fasta(text):
61
  """
62
+ Parse FASTA text, returning a list of (header, sequence).
63
+ We'll only use the *first* sequence in practice.
64
  """
65
  sequences = []
66
  current_header = None
 
83
 
84
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
85
  """
86
+ Convert a single sequence to a 4^k dimension k-mer frequency vector.
 
87
  """
88
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
89
  kmer_dict = {km: i for i, km in enumerate(kmers)}
 
96
 
97
  total_kmers = len(sequence) - k + 1
98
  if total_kmers > 0:
99
+ vec /= total_kmers
100
 
101
  return vec
102
 
 
103
  ###############################################################################
104
+ # Visualization
105
  ###############################################################################
106
+ def create_waterfall_plot(shap_values, base_value, data, max_display=15):
107
  """
108
+ Create a SHAP waterfall plot for a single sample's single class
109
+ (shap_values: shape=(num_features,))
110
+ (base_value: scalar)
111
+ (data: original input data for that sample, shape=(num_features,))
112
  """
113
+ # Build a shap Explanation object
114
+ expl = shap.Explanation(
115
+ values=shap_values,
116
+ base_values=base_value,
117
+ data=data,
118
+ feature_names=[f"feat_{i}" for i in range(len(shap_values))]
119
+ )
120
 
121
+ fig = plt.figure(figsize=(6, 4), dpi=75)
122
+ shap.plots.waterfall(expl, max_display=max_display, show=False)
123
+ buf = io.BytesIO()
124
+ plt.savefig(buf, format='png', bbox_inches='tight')
125
+ buf.seek(0)
126
+ im = Image.open(buf)
127
+ plt.close(fig)
128
+ return im
129
+
130
+ def create_freq_sigma_plot(shap_values, raw_freq, scaled_vec, kmer_list, title="Top-10 k-mers"):
131
+ """
132
+ Show top-10 k-mers by absolute shap value with frequency (%) & z-score.
133
+ shap_values: (256,)
134
+ raw_freq: (256,) unscaled frequency
135
+ scaled_vec: (256,) scaled frequency (z-scores)
136
+ kmer_list: list of length=256
137
+ """
138
+ abs_vals = np.abs(shap_values)
139
+ top_indices = np.argsort(abs_vals)[-10:][::-1] # top 10
140
  top_data = []
141
+
142
  for idx in top_indices:
143
+ idx = int(idx) # ensure it's an integer
144
  top_data.append({
145
+ "kmer": kmer_list[idx],
146
+ "shap": shap_values[idx],
147
+ "abs_shap": abs_vals[idx],
148
+ "freq": raw_freq[idx] * 100.0,
149
+ "sigma": scaled_vec[idx]
150
  })
151
 
152
+ # Sort again by abs_shap desc
153
  top_data.sort(key=lambda x: x["abs_shap"], reverse=True)
154
 
155
+ kmers = [d["kmer"] for d in top_data]
156
+ freqs = [d["freq"] for d in top_data]
157
+ sigmas = [d["sigma"] for d in top_data]
158
+ colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
 
 
159
 
160
  x = np.arange(len(kmers))
161
  width = 0.4
162
 
163
+ fig, ax = plt.subplots(figsize=(6, 4), dpi=75)
164
+
165
+ ax.bar(x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
166
+ ax.set_ylabel("Frequency (%)")
167
+ if freqs:
168
+ ax.set_ylim(0, max(freqs)*1.25)
 
 
169
 
 
170
  ax2 = ax.twinx()
171
+ ax2.bar(x + width/2, sigmas, width, color="gray", alpha=0.5, label="Z-score")
172
+ ax2.set_ylabel("Standard Deviations (σ)")
 
 
173
 
174
  ax.set_xticks(x)
175
  ax.set_xticklabels(kmers, rotation=45, ha='right')
176
+ ax.set_title(title)
177
 
 
178
  lines1, labels1 = ax.get_legend_handles_labels()
179
  lines2, labels2 = ax2.get_legend_handles_labels()
180
  ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
181
 
182
  plt.tight_layout()
183
+ buf = io.BytesIO()
184
+ fig.savefig(buf, format='png', bbox_inches='tight')
185
+ buf.seek(0)
186
+ im = Image.open(buf)
187
+ plt.close(fig)
188
+ return im
189
 
190
 
191
  ###############################################################################
192
+ # Main Gradio Logic
193
  ###############################################################################
194
+ def classify_and_explain(file_obj):
195
  """
196
+ - Reads the first sequence from the uploaded FASTA
197
+ - Loads model & scaler
198
+ - Produces a single prediction + shap explanation
199
+ - Returns Markdown + Waterfall image + freq/sigma image
200
  """
201
+ # 1. Read text from user input
202
+ try:
203
+ if isinstance(file_obj, str):
204
+ text = file_obj
205
+ else:
206
  text = file_obj.decode("utf-8")
207
+ except:
208
+ return "Error reading file!", None, None
209
 
210
+ # 2. Parse
211
  sequences = parse_fasta(text)
212
+ if not sequences:
213
+ return "No valid FASTA sequences found!", None, None
214
+ header, seq = sequences[0] # only the first sequence
215
 
216
+ # 3. Convert to k-mer
217
  k = 4
218
+ raw_freq_vec = sequence_to_kmer_vector(seq, k=k) # shape=(256,)
219
 
220
  # 4. Load model & scaler
221
+ device = "cuda" if torch.cuda.is_available() else "cpu"
222
+ model = VirusClassifier(4**k).to(device)
223
  try:
224
  state_dict = torch.load("model.pt", map_location=device, weights_only=True)
225
  model.load_state_dict(state_dict)
226
  model.eval()
 
227
  scaler = joblib.load("scaler.pkl")
228
  except Exception as e:
229
+ return f"Error loading model/scaler: {str(e)}", None, None
230
 
231
  # 5. Scale data
232
+ scaled_data = scaler.transform(raw_freq_vec.reshape(1, -1)) # shape=(1, 256)
 
 
233
  X_tensor = torch.FloatTensor(scaled_data).to(device)
234
 
235
+ # 6. Predict
236
+ with torch.no_grad():
237
+ out = model(X_tensor) # shape=(1,2)
238
+ probs = torch.softmax(out, dim=1).cpu().numpy()[0] # shape=(2,)
239
+ pred_class = np.argmax(probs)
240
+ pred_label = "human" if pred_class == 1 else "non-human"
241
+ confidence = float(np.max(probs))
242
+ human_prob = float(probs[1])
243
+ nonhuman_prob = float(probs[0])
244
+
245
+ # 7. SHAP Explanation (single sample)
246
+ # We'll wrap the model for shap
247
  wrapped_model = TorchModelWrapper(model, device)
248
+ background_data = scaled_data # 1 sample as background (a bit silly, but simpler)
249
  explainer = shap.Explainer(wrapped_model, background_data)
250
+ shap_values = explainer(scaled_data) # shape => (1, 2, 256) for 2-class output
251
+
252
+ # We'll pick class=1's shap values
253
+ sample_shap_vals = shap_values.values[0, 1, :] # shape=(256,)
254
+ base_value = shap_values.base_values[0, 1]
255
+ sample_data = shap_values.data[0] # shape=(256,)
256
+
257
+ # 8. Create the two plots
258
+ wf_img = create_waterfall_plot(
259
+ shap_values=sample_shap_vals,
260
+ base_value=base_value,
261
+ data=sample_data,
262
+ max_display=15
263
  )
264
 
265
+ # freq-sigma plot
266
+ kmers = [''.join(p) for p in product("ACGT", repeat=k)]
267
+ freq_img = create_freq_sigma_plot(
268
+ shap_values=sample_shap_vals,
269
+ raw_freq=raw_freq_vec,
270
+ scaled_vec=scaled_data[0],
271
+ kmer_list=kmers,
272
+ title=f"{header[:25]}... (Top-10 K-mers)"
273
  )
274
 
275
+ # 9. Markdown result
276
+ result_md = f"""
277
+ # Classification Result
278
 
279
+ **Header**: {header}
280
 
281
+ **Predicted Label**: {pred_label}
282
+ **Confidence**: {confidence:.4f}
 
 
283
 
284
+ **Human Probability**: {human_prob:.4f}
285
+ **Non-human Probability**: {nonhuman_prob:.4f}
286
 
287
+ Above are the SHAP-based analyses (class=1).
288
+ """
289
 
290
+ return result_md, wf_img, freq_img
291
 
292
 
293
  ###############################################################################
294
  # Gradio Interface
295
  ###############################################################################
296
+ with gr.Blocks(title="Single-Sequence Virus Host Classifier") as demo:
297
+ gr.Markdown("## Upload a FASTA file containing **one** (or more) sequences. We only use the **first**.")
298
 
299
+ file_input = gr.File(label="Upload FASTA", type="binary")
300
+ run_btn = gr.Button("Classify & Explain")
301
 
302
  with gr.Tabs():
303
+ with gr.Tab("Results"):
304
  md_out = gr.Markdown()
 
305
  with gr.Tab("SHAP Waterfall"):
306
+ wf_out = gr.Image(label="Waterfall Plot (class=1)")
307
+ with gr.Tab("Top-10 K-mer Plot"):
308
+ freq_out = gr.Image(label="Frequency & Sigma (class=1)")
309
 
 
310
  run_btn.click(
311
+ fn=classify_and_explain,
312
  inputs=[file_input],
313
+ outputs=[md_out, wf_out, freq_out]
314
  )
315
 
316
+ # No share=True -> avoid HF Spaces warning
317
  if __name__ == "__main__":
318
+ demo.launch()
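
A minimal sketch of how the new single-sequence path could be exercised outside Gradio, assuming app.py from this diff is importable from the working directory (importing it builds the Gradio Blocks but does not launch the server) and that model.pt and scaler.pkl are present as the code expects. The FASTA string and the assertion are illustrative only, not part of the commit:

```python
# Hypothetical smoke test for the single-sequence flow (illustration only).
import torch
import joblib

from app import parse_fasta, sequence_to_kmer_vector, VirusClassifier

fasta_text = ">toy_sequence\nACGTACGTACGTACGTACGTACGT"   # made-up input
header, seq = parse_fasta(fasta_text)[0]                 # first record only, as in classify_and_explain

raw_vec = sequence_to_kmer_vector(seq, k=4)              # shape (256,); normalized 4-mer frequencies
assert raw_vec.shape == (256,) and abs(raw_vec.sum() - 1.0) < 1e-6

device = "cuda" if torch.cuda.is_available() else "cpu"
model = VirusClassifier(4**4).to(device)
model.load_state_dict(torch.load("model.pt", map_location=device, weights_only=True))
model.eval()

scaler = joblib.load("scaler.pkl")
scaled = scaler.transform(raw_vec.reshape(1, -1))        # shape (1, 256)

with torch.no_grad():
    probs = torch.softmax(model(torch.FloatTensor(scaled).to(device)), dim=1).cpu().numpy()[0]

print(header, "human" if probs[1] >= probs[0] else "non-human", probs)
```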