hiyata commited on
Commit
3b775b7
·
verified ·
1 Parent(s): 7e19501

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +370 -198
app.py CHANGED
@@ -2,14 +2,13 @@ import gradio as gr
2
  import torch
3
  import joblib
4
  import numpy as np
 
 
5
  from itertools import product
6
  import torch.nn as nn
7
- import matplotlib
8
- matplotlib.use("Agg") # If no display environment
9
  import matplotlib.pyplot as plt
10
  import io
11
  from PIL import Image
12
- import shap
13
 
14
  ###############################################################################
15
  # Model Definition
@@ -33,35 +32,30 @@ class VirusClassifier(nn.Module):
33
 
34
  def forward(self, x):
35
  return self.network(x)
36
-
37
-
38
- ###############################################################################
39
- # Torch Model Wrapper for SHAP
40
- ###############################################################################
41
- class TorchModelWrapper:
42
- """
43
- A simple callable that converts incoming NumPy arrays to torch Tensors,
44
- does a forward pass, and returns NumPy arrays. This is needed for shap.
45
- """
46
- def __init__(self, model: nn.Module, device='cpu'):
47
- self.model = model
48
- self.device = device
49
-
50
- def __call__(self, x_np: np.ndarray):
51
- x_torch = torch.from_numpy(x_np).float().to(self.device)
52
- with torch.no_grad():
53
- out = self.model(x_torch).cpu().numpy() # shape=(batch,2)
54
- return out
55
-
56
 
57
  ###############################################################################
58
  # Utility Functions
59
  ###############################################################################
60
  def parse_fasta(text):
61
- """
62
- Parse FASTA text, returning a list of (header, sequence).
63
- We'll only use the *first* sequence in practice.
64
- """
65
  sequences = []
66
  current_header = None
67
  current_sequence = []
@@ -82,9 +76,7 @@ def parse_fasta(text):
82
  return sequences
83
 
84
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
85
- """
86
- Convert a single sequence to a 4^k dimension k-mer frequency vector.
87
- """
88
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
89
  kmer_dict = {km: i for i, km in enumerate(kmers)}
90
  vec = np.zeros(len(kmers), dtype=np.float32)
@@ -96,223 +88,403 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
96
 
97
  total_kmers = len(sequence) - k + 1
98
  if total_kmers > 0:
99
- vec /= total_kmers
100
 
101
  return vec
102
 
103
  ###############################################################################
104
- # Visualization
105
  ###############################################################################
106
- def create_waterfall_plot(shap_values, base_value, data, max_display=15):
107
  """
108
- Create a SHAP waterfall plot for a single sample's single class
109
- (shap_values: shape=(num_features,))
110
- (base_value: scalar)
111
- (data: original input data for that sample, shape=(num_features,))
112
  """
113
- # Build a shap Explanation object
114
- expl = shap.Explanation(
115
- values=shap_values,
116
- base_values=base_value,
117
- data=data,
118
- feature_names=[f"feat_{i}" for i in range(len(shap_values))]
119
- )
120
-
121
- fig = plt.figure(figsize=(6, 4), dpi=75)
122
- shap.plots.waterfall(expl, max_display=max_display, show=False)
123
- buf = io.BytesIO()
124
- plt.savefig(buf, format='png', bbox_inches='tight')
125
- buf.seek(0)
126
- im = Image.open(buf)
127
- plt.close(fig)
128
- return im
129
 
130
- def create_freq_sigma_plot(shap_values, raw_freq, scaled_vec, kmer_list, title="Top-10 k-mers"):
131
  """
132
- Show top-10 k-mers by absolute shap value with frequency (%) & z-score.
133
- shap_values: (256,)
134
- raw_freq: (256,) unscaled frequency
135
- scaled_vec: (256,) scaled frequency (z-scores)
136
- kmer_list: list of length=256
137
  """
138
- abs_vals = np.abs(shap_values)
139
- top_indices = np.argsort(abs_vals)[-10:][::-1] # top 10
140
- top_data = []
 
 
 
141
 
142
- for idx in top_indices:
143
- idx = int(idx) # ensure it's an integer
144
- top_data.append({
145
- "kmer": kmer_list[idx],
146
- "shap": shap_values[idx],
147
- "abs_shap": abs_vals[idx],
148
- "freq": raw_freq[idx] * 100.0,
149
- "sigma": scaled_vec[idx]
150
- })
151
-
152
- # Sort again by abs_shap desc
153
- top_data.sort(key=lambda x: x["abs_shap"], reverse=True)
154
-
155
- kmers = [d["kmer"] for d in top_data]
156
- freqs = [d["freq"] for d in top_data]
157
- sigmas = [d["sigma"] for d in top_data]
158
- colors = ["green" if d["shap"] >= 0 else "red" for d in top_data]
159
-
160
  x = np.arange(len(kmers))
161
  width = 0.4
162
 
163
- fig, ax = plt.subplots(figsize=(6, 4), dpi=75)
164
 
165
- ax.bar(x - width/2, freqs, width, color=colors, alpha=0.7, label="Frequency (%)")
166
- ax.set_ylabel("Frequency (%)")
167
- if freqs:
168
- ax.set_ylim(0, max(freqs)*1.25)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
169
 
170
- ax2 = ax.twinx()
171
- ax2.bar(x + width/2, sigmas, width, color="gray", alpha=0.5, label="Z-score")
172
- ax2.set_ylabel("Standard Deviations (σ)")
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  ax.set_xticks(x)
175
  ax.set_xticklabels(kmers, rotation=45, ha='right')
176
- ax.set_title(title)
177
-
178
- lines1, labels1 = ax.get_legend_handles_labels()
179
- lines2, labels2 = ax2.get_legend_handles_labels()
180
- ax.legend(lines1 + lines2, labels1 + labels2, loc='upper right')
181
 
182
  plt.tight_layout()
183
- buf = io.BytesIO()
184
- fig.savefig(buf, format='png', bbox_inches='tight')
185
- buf.seek(0)
186
- im = Image.open(buf)
187
- plt.close(fig)
188
- return im
189
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  ###############################################################################
192
- # Main Gradio Logic
193
  ###############################################################################
194
- def classify_and_explain(file_obj):
195
  """
196
- - Reads the first sequence from the uploaded FASTA
197
- - Loads model & scaler
198
- - Produces a single prediction + shap explanation
199
- - Returns Markdown + Waterfall image + freq/sigma image
 
 
 
 
 
 
200
  """
201
- # 1. Read text from user input
 
 
 
 
 
 
 
 
 
202
  try:
203
  if isinstance(file_obj, str):
204
  text = file_obj
205
  else:
206
- text = file_obj.decode("utf-8")
207
- except:
208
- return "Error reading file!", None, None
 
 
 
 
 
 
209
 
210
- # 2. Parse
211
  sequences = parse_fasta(text)
212
- if not sequences:
213
- return "No valid FASTA sequences found!", None, None
214
- header, seq = sequences[0] # only the first sequence
 
 
 
 
 
 
215
 
216
- # 3. Convert to k-mer
217
  k = 4
218
- raw_freq_vec = sequence_to_kmer_vector(seq, k=k) # shape=(256,)
219
-
220
- # 4. Load model & scaler
221
  device = "cuda" if torch.cuda.is_available() else "cpu"
222
- model = VirusClassifier(4**k).to(device)
223
  try:
224
- state_dict = torch.load("model.pt", map_location=device, weights_only=True)
 
 
 
 
225
  model.load_state_dict(state_dict)
226
- model.eval()
227
  scaler = joblib.load("scaler.pkl")
228
- except Exception as e:
229
- return f"Error loading model/scaler: {str(e)}", None, None
230
-
231
- # 5. Scale data
232
- scaled_data = scaler.transform(raw_freq_vec.reshape(1, -1)) # shape=(1, 256)
233
- X_tensor = torch.FloatTensor(scaled_data).to(device)
234
-
235
- # 6. Predict
236
- with torch.no_grad():
237
- out = model(X_tensor) # shape=(1,2)
238
- probs = torch.softmax(out, dim=1).cpu().numpy()[0] # shape=(2,)
239
- pred_class = np.argmax(probs)
240
- pred_label = "human" if pred_class == 1 else "non-human"
241
- confidence = float(np.max(probs))
242
- human_prob = float(probs[1])
243
- nonhuman_prob = float(probs[0])
244
-
245
- # 7. SHAP Explanation (single sample)
246
- # We'll wrap the model for shap
247
- wrapped_model = TorchModelWrapper(model, device)
248
- background_data = scaled_data # 1 sample as background (a bit silly, but simpler)
249
- explainer = shap.Explainer(wrapped_model, background_data)
250
- shap_values = explainer(scaled_data) # shape => (1, 2, 256) for 2-class output
251
-
252
- # We'll pick class=1's shap values
253
- sample_shap_vals = shap_values.values[0, 1, :] # shape=(256,)
254
- base_value = shap_values.base_values[0, 1]
255
- sample_data = shap_values.data[0] # shape=(256,)
256
-
257
- # 8. Create the two plots
258
- wf_img = create_waterfall_plot(
259
- shap_values=sample_shap_vals,
260
- base_value=base_value,
261
- data=sample_data,
262
- max_display=15
263
- )
264
-
265
- # freq-sigma plot
266
- kmers = [''.join(p) for p in product("ACGT", repeat=k)]
267
- freq_img = create_freq_sigma_plot(
268
- shap_values=sample_shap_vals,
269
- raw_freq=raw_freq_vec,
270
- scaled_vec=scaled_data[0],
271
- kmer_list=kmers,
272
- title=f"{header[:25]}... (Top-10 K-mers)"
273
- )
274
-
275
- # 9. Markdown result
276
- result_md = f"""
277
- # Classification Result
278
-
279
- **Header**: {header}
280
-
281
- **Predicted Label**: {pred_label}
282
- **Confidence**: {confidence:.4f}
283
-
284
- **Human Probability**: {human_prob:.4f}
285
- **Non-human Probability**: {nonhuman_prob:.4f}
286
 
287
- Above are the SHAP-based analyses (class=1).
288
- """
289
 
290
- return result_md, wf_img, freq_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
 
292
 
293
  ###############################################################################
294
  # Gradio Interface
295
  ###############################################################################
296
- with gr.Blocks(title="Single-Sequence Virus Host Classifier") as demo:
297
- gr.Markdown("## Upload a FASTA file containing **one** (or more) sequences. We only use the **first**.")
298
-
299
- file_input = gr.File(label="Upload FASTA", type="binary")
300
- run_btn = gr.Button("Classify & Explain")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
 
302
  with gr.Tabs():
303
- with gr.Tab("Results"):
304
  md_out = gr.Markdown()
305
- with gr.Tab("SHAP Waterfall"):
306
- wf_out = gr.Image(label="Waterfall Plot (class=1)")
307
- with gr.Tab("Top-10 K-mer Plot"):
308
- freq_out = gr.Image(label="Frequency & Sigma (class=1)")
309
-
310
- run_btn.click(
311
- fn=classify_and_explain,
312
- inputs=[file_input],
313
- outputs=[md_out, wf_out, freq_out]
 
 
 
 
 
314
  )
315
 
316
- # No share=True -> avoid HF Spaces warning
317
  if __name__ == "__main__":
318
- demo.launch()
 
 
2
  import torch
3
  import joblib
4
  import numpy as np
5
+ import shap
6
+ import random
7
  from itertools import product
8
  import torch.nn as nn
 
 
9
  import matplotlib.pyplot as plt
10
  import io
11
  from PIL import Image
 
12
 
13
  ###############################################################################
14
  # Model Definition
 
32
 
33
  def forward(self, x):
34
  return self.network(x)
35
+
36
+ def get_feature_importance(self, x):
37
+ """
38
+ Calculate gradient-based feature importance, specifically for the
39
+ 'human' class (index=1) by computing gradient of that probability wrt x.
40
+ """
41
+ x.requires_grad_(True)
42
+ output = self.network(x)
43
+ probs = torch.softmax(output, dim=1)
44
+
45
+ # Probability of 'human' class (index=1)
46
+ human_prob = probs[..., 1]
47
+ if x.grad is not None:
48
+ x.grad.zero_()
49
+ human_prob.backward()
50
+ importance = x.grad # shape: (batch_size, n_features)
51
+
52
+ return importance, float(human_prob)
 
 
53
 
54
  ###############################################################################
55
  # Utility Functions
56
  ###############################################################################
57
  def parse_fasta(text):
58
+ """Parses text input in FASTA format into a list of (header, sequence)."""
 
 
 
59
  sequences = []
60
  current_header = None
61
  current_sequence = []
 
76
  return sequences
77
 
78
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
79
+ """Convert a single nucleotide sequence to a k-mer frequency vector."""
 
 
80
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
81
  kmer_dict = {km: i for i, km in enumerate(kmers)}
82
  vec = np.zeros(len(kmers), dtype=np.float32)
 
88
 
89
  total_kmers = len(sequence) - k + 1
90
  if total_kmers > 0:
91
+ vec = vec / total_kmers # normalize frequencies
92
 
93
  return vec
94
 
95
  ###############################################################################
96
+ # Additional Plots
97
  ###############################################################################
98
+ def create_probability_bar_plot(prob_human, prob_nonhuman):
99
  """
100
+ Simple bar plot comparing human vs. non-human probabilities.
 
 
 
101
  """
102
+ labels = ["Non-human", "Human"]
103
+ probs = [prob_nonhuman, prob_human]
104
+ colors = ["red", "green"]
105
+
106
+ fig, ax = plt.subplots(figsize=(6, 4))
107
+ ax.bar(labels, probs, color=colors, alpha=0.7)
108
+ ax.set_ylim(0, 1)
109
+ for i, v in enumerate(probs):
110
+ ax.text(i, v+0.02, f"{v:.3f}", ha='center', color='black', fontsize=11)
111
+
112
+ ax.set_title("Predicted Probabilities")
113
+ ax.set_ylabel("Probability")
114
+ plt.tight_layout()
115
+ return fig
 
 
116
 
117
+ def create_frequency_sigma_plot(important_kmers, title):
118
  """
119
+ Creates a bar plot of the top k-mers (by importance) showing
120
+ frequency (%) and σ from mean.
 
 
 
121
  """
122
+ # Sort by absolute impact
123
+ sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
124
+ kmers = [k["kmer"] for k in sorted_kmers]
125
+ frequencies = [k["occurrence"] for k in sorted_kmers] # in %
126
+ sigmas = [k["sigma"] for k in sorted_kmers]
127
+ directions = [k["direction"] for k in sorted_kmers]
128
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
129
  x = np.arange(len(kmers))
130
  width = 0.4
131
 
132
+ fig, ax_bar = plt.subplots(figsize=(10, 5))
133
 
134
+ # Bar for frequency
135
+ bars_freq = ax_bar.bar(
136
+ x - width/2, frequencies, width, alpha=0.7,
137
+ color=["green" if d=="human" else "red" for d in directions],
138
+ label="Frequency (%)"
139
+ )
140
+ ax_bar.set_ylabel("Frequency (%)")
141
+ ax_bar.set_ylim(0, max(frequencies) * 1.2 if len(frequencies) > 0 else 1)
142
+
143
+ # Twin axis for σ
144
+ ax_bar_twin = ax_bar.twinx()
145
+ bars_sigma = ax_bar_twin.bar(
146
+ x + width/2, sigmas, width, alpha=0.5, color="gray", label="σ from Mean"
147
+ )
148
+ ax_bar_twin.set_ylabel("Standard Deviations (σ)")
149
+
150
+ ax_bar.set_title(f"Frequency & σ from Mean for Top k-mers — {title}")
151
+ ax_bar.set_xticks(x)
152
+ ax_bar.set_xticklabels(kmers, rotation=45, ha='right')
153
+
154
+ # Combined legend
155
+ lines1, labels1 = ax_bar.get_legend_handles_labels()
156
+ lines2, labels2 = ax_bar_twin.get_legend_handles_labels()
157
+ ax_bar.legend(lines1 + lines2, labels1 + labels2, loc="upper right")
158
 
159
+ plt.tight_layout()
160
+ return fig
 
161
 
162
+ def create_importance_bar_plot(important_kmers, title):
163
+ """
164
+ Create a simple bar chart showing the absolute gradient magnitude
165
+ for the top k-mers, sorted descending.
166
+ """
167
+ sorted_kmers = sorted(important_kmers, key=lambda x: x['impact'], reverse=True)
168
+ kmers = [k['kmer'] for k in sorted_kmers]
169
+ impacts = [k['impact'] for k in sorted_kmers]
170
+ directions = [k["direction"] for k in sorted_kmers]
171
+
172
+ x = np.arange(len(kmers))
173
+
174
+ fig, ax = plt.subplots(figsize=(10, 5))
175
+ bar_colors = ["green" if d=="human" else "red" for d in directions]
176
+
177
+ ax.bar(x, impacts, color=bar_colors, alpha=0.7, edgecolor='black')
178
  ax.set_xticks(x)
179
  ax.set_xticklabels(kmers, rotation=45, ha='right')
180
+ ax.set_title(f"Absolute Feature Importance (Top k-mers) — {title}")
181
+ ax.set_ylabel("Gradient Magnitude")
182
+ ax.grid(axis="y", alpha=0.3)
 
 
183
 
184
  plt.tight_layout()
185
+ return fig
 
 
 
 
 
186
 
187
+ ###############################################################################
188
+ # SHAP Beeswarm
189
+ ###############################################################################
190
+ def create_shap_beeswarm_plot(
191
+ model,
192
+ input_vector: np.ndarray,
193
+ background_data: np.ndarray,
194
+ feature_names: list
195
+ ):
196
+ """
197
+ Creates a SHAP beeswarm plot using KernelExplainer for the given model and data.
198
+
199
+ Parameters
200
+ ----------
201
+ model : nn.Module
202
+ Trained PyTorch model (binary classifier).
203
+ input_vector : np.ndarray
204
+ The 1-sample input (or multiple samples) we want SHAP values for.
205
+ background_data : np.ndarray
206
+ Background samples for KernelExplainer. Should have shape (N, #features).
207
+ feature_names : list
208
+ Names for each feature (k-mers).
209
+
210
+ Returns
211
+ -------
212
+ fig : matplotlib Figure
213
+ Beeswarm plot figure.
214
+ """
215
+
216
+ # We'll define a prediction function that shap can call
217
+ # The model outputs logits for shape [N, 2]
218
+ # We want the raw outputs for each class. SHAP will handle the link function if needed.
219
+ def predict_fn(data):
220
+ """
221
+ data: shape (N, #features)
222
+ returns: shape (N, 2) for 2-class logits
223
+ """
224
+ with torch.no_grad():
225
+ x = torch.FloatTensor(data)
226
+ logits = model(x)
227
+ return logits.detach().cpu().numpy()
228
+
229
+ # Create KernelExplainer
230
+ explainer = shap.KernelExplainer(
231
+ model=predict_fn,
232
+ data=background_data
233
+ )
234
+
235
+ # Compute SHAP values
236
+ # For a 2-class model, shap_values is a list of length 2 => [class0 array, class1 array]
237
+ # Each array is shape (N, #features).
238
+ shap_values = explainer.shap_values(input_vector)
239
+
240
+ # We’ll produce a beeswarm for the 'human' class (class index=1).
241
+ # If we have only 1 sample, the beeswarm won't be too interesting, but let's do it anyway.
242
+ class_idx = 1 # 'human'
243
+
244
+ # If we only have one sample, place it in an array for shap summary plotting:
245
+ # We can do shap_values[class_idx].shape => (1, #features) for a single sample
246
+ # Beeswarm typically expects multiple samples. We'll plot anyway.
247
+ shap.plots.beeswarm(
248
+ shap_values[class_idx],
249
+ feature_names=feature_names,
250
+ show=False
251
+ )
252
+
253
+ fig = plt.gcf()
254
+ fig.set_size_inches(8, 6)
255
+ plt.title("SHAP Beeswarm Plot (Class: Human)")
256
+
257
+ plt.tight_layout()
258
+ return fig
259
 
260
  ###############################################################################
261
+ # Prediction Function
262
  ###############################################################################
263
+ def predict(file_obj):
264
  """
265
+ Main function for Gradio:
266
+ 1. Reads the uploaded FASTA file or text.
267
+ 2. Loads the model and scaler.
268
+ 3. Generates predictions, probabilities, and top k-mers.
269
+ 4. Creates multiple outputs:
270
+ - Text summary (Markdown)
271
+ - Probability Bar Plot
272
+ - SHAP Beeswarm Plot
273
+ - Frequency & σ Plot
274
+ - Absolute Feature Importance Bar Plot
275
  """
276
+ # 0. Basic file read
277
+ if file_obj is None:
278
+ return (
279
+ "Please upload a FASTA file.",
280
+ None,
281
+ None,
282
+ None,
283
+ None
284
+ )
285
+
286
  try:
287
  if isinstance(file_obj, str):
288
  text = file_obj
289
  else:
290
+ text = file_obj.decode('utf-8')
291
+ except Exception as e:
292
+ return (
293
+ f"Error reading file: {str(e)}",
294
+ None,
295
+ None,
296
+ None,
297
+ None
298
+ )
299
 
300
+ # 1. Parse FASTA
301
  sequences = parse_fasta(text)
302
+ if len(sequences) == 0:
303
+ return (
304
+ "No valid FASTA sequences found. Please check your input.",
305
+ None,
306
+ None,
307
+ None,
308
+ None
309
+ )
310
+ header, seq = sequences[0] # We'll classify only the first sequence
311
 
312
+ # 2. Prepare model, scaler, and input
313
  k = 4
 
 
 
314
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
315
  try:
316
+ raw_freq_vector = sequence_to_kmer_vector(seq, k=k)
317
+
318
+ # Load model & scaler
319
+ model = VirusClassifier(input_shape=4**k).to(device)
320
+ state_dict = torch.load("model.pt", map_location=device)
321
  model.load_state_dict(state_dict)
 
322
  scaler = joblib.load("scaler.pkl")
323
+ model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
 
325
+ scaled_vector = scaler.transform(raw_freq_vector.reshape(1, -1))
326
+ X_tensor = torch.FloatTensor(scaled_vector).to(device)
327
 
328
+ # 3. Predict
329
+ with torch.no_grad():
330
+ logits = model(X_tensor)
331
+ probs = torch.softmax(logits, dim=1)
332
+ human_prob = float(probs[0][1])
333
+ non_human_prob = float(probs[0][0])
334
+ pred_label = "human" if human_prob >= non_human_prob else "non-human"
335
+ confidence = float(max(probs[0]))
336
+
337
+ # 4. Gradient-based feature importance
338
+ importance, hum_prob_grad = model.get_feature_importance(X_tensor)
339
+ importances = importance[0].cpu().numpy() # shape: (#features,)
340
+ abs_importances = np.abs(importances)
341
+
342
+ # 5. Gather k-mer strings
343
+ kmers_list = [''.join(p) for p in product("ACGT", repeat=k)]
344
+ # top 10 by absolute importance
345
+ top_k = 10
346
+ top_idxs = np.argsort(abs_importances)[-top_k:][::-1]
347
+ important_kmers = []
348
+ for idx in top_idxs:
349
+ direction = "human" if importances[idx] > 0 else "non-human"
350
+ freq_percent = float(raw_freq_vector[idx] * 100.0)
351
+ sigma_val = float(scaled_vector[0][idx]) # scaled / standardized val
352
+ important_kmers.append({
353
+ 'kmer': kmers_list[idx],
354
+ 'idx': idx,
355
+ 'impact': abs_importances[idx],
356
+ 'direction': direction,
357
+ 'occurrence': freq_percent,
358
+ 'sigma': sigma_val
359
+ })
360
+
361
+ # 6. Generate text summary
362
+ text_summary = (
363
+ f"**Sequence Header**: {header}\n\n"
364
+ f"**Predicted Label**: {pred_label}\n"
365
+ f"**Confidence**: {confidence:.4f}\n\n"
366
+ f"**Human Probability**: {human_prob:.4f}\n"
367
+ f"**Non-human Probability**: {non_human_prob:.4f}\n\n"
368
+ "### Most Influential k-mers:\n"
369
+ )
370
+ for km in important_kmers:
371
+ direction_text = f"(pushes toward {km['direction']})"
372
+ freq_text = f"{km['occurrence']:.2f}%"
373
+ sigma_text = (
374
+ f"{abs(km['sigma']):.2f}σ "
375
+ + ("above" if km['sigma'] > 0 else "below")
376
+ + " mean"
377
+ )
378
+ text_summary += (
379
+ f"- **{km['kmer']}**: impact={km['impact']:.4f}, {direction_text}, "
380
+ f"occurrence={freq_text}, ({sigma_text})\n"
381
+ )
382
+
383
+ # 7. Probability Bar Plot
384
+ fig_prob = create_probability_bar_plot(human_prob, non_human_prob)
385
+ buf_prob = io.BytesIO()
386
+ fig_prob.savefig(buf_prob, format='png', bbox_inches='tight', dpi=120)
387
+ buf_prob.seek(0)
388
+ prob_img = Image.open(buf_prob)
389
+ plt.close(fig_prob)
390
+
391
+ # 8. SHAP Beeswarm Plot
392
+ # We need some background data for KernelExplainer. Let's create a small random sample
393
+ # or sample from the scaled_vector itself in a repeated manner. Real usage: choose a valid background set.
394
+ background_size = 5 # keep small for speed
395
+ # We'll pick random sequences from normal(0,1) or from scaled_vector repeated
396
+ background_data = []
397
+ for _ in range(background_size):
398
+ # Option A: random small variations around scaled_vector
399
+ # new_sample = scaled_vector[0] + np.random.normal(0, 0.5, size=scaled_vector.shape[1])
400
+ # Option B: just clone the same scaled vector multiple times
401
+ new_sample = scaled_vector[0]
402
+ background_data.append(new_sample)
403
+ background_data = np.stack(background_data, axis=0) # shape (5, #features)
404
+
405
+ fig_bee = create_shap_beeswarm_plot(
406
+ model=model,
407
+ input_vector=scaled_vector, # our single sample
408
+ background_data=background_data, # background for KernelExplainer
409
+ feature_names=kmers_list
410
+ )
411
+ buf_bee = io.BytesIO()
412
+ fig_bee.savefig(buf_bee, format='png', bbox_inches='tight', dpi=120)
413
+ buf_bee.seek(0)
414
+ bee_img = Image.open(buf_bee)
415
+ plt.close(fig_bee)
416
+
417
+ # 9. Frequency & σ Plot
418
+ fig_freq = create_frequency_sigma_plot(important_kmers, header)
419
+ buf_freq = io.BytesIO()
420
+ fig_freq.savefig(buf_freq, format='png', bbox_inches='tight', dpi=120)
421
+ buf_freq.seek(0)
422
+ freq_img = Image.open(buf_freq)
423
+ plt.close(fig_freq)
424
+
425
+ # 10. Absolute Feature Importance Bar Plot
426
+ fig_imp = create_importance_bar_plot(important_kmers, header)
427
+ buf_imp = io.BytesIO()
428
+ fig_imp.savefig(buf_imp, format='png', bbox_inches='tight', dpi=120)
429
+ buf_imp.seek(0)
430
+ imp_img = Image.open(buf_imp)
431
+ plt.close(fig_imp)
432
+
433
+ return text_summary, prob_img, bee_img, freq_img, imp_img
434
+
435
+ except Exception as e:
436
+ return (
437
+ f"Error during prediction or visualization: {str(e)}",
438
+ None,
439
+ None,
440
+ None,
441
+ None
442
+ )
443
 
444
 
445
  ###############################################################################
446
  # Gradio Interface
447
  ###############################################################################
448
+ with gr.Blocks(title="Advanced Virus Host Classifier with SHAP Beeswarm") as demo:
449
+ gr.Markdown(
450
+ """
451
+ # Advanced Virus Host Classifier (SHAP Beeswarm Edition)
452
+
453
+ **Upload a FASTA file** containing a single nucleotide sequence.
454
+ The model will predict whether this sequence is **human** or **non-human**,
455
+ provide a confidence score, and highlight the most influential k-mers.
456
+ We also produce a **SHAP beeswarm** plot for the features.
457
+
458
+ ---
459
+ **Note**: Beeswarm plots are usually most insightful with multiple samples.
460
+ Here, we demonstrate usage with a single sample plus a small synthetic background.
461
+ """
462
+ )
463
+
464
+ with gr.Row():
465
+ file_in = gr.File(label="Upload FASTA", type="binary")
466
+ btn = gr.Button("Run Prediction")
467
 
468
+ # We will create multiple tabs for our outputs
469
  with gr.Tabs():
470
+ with gr.Tab("Prediction Results"):
471
  md_out = gr.Markdown()
472
+ with gr.Tab("Probability Plot"):
473
+ prob_out = gr.Image()
474
+ with gr.Tab("SHAP Beeswarm Plot"):
475
+ bee_out = gr.Image()
476
+ with gr.Tab("Frequency & σ Plot"):
477
+ freq_out = gr.Image()
478
+ with gr.Tab("Importance Bar Plot"):
479
+ imp_out = gr.Image()
480
+
481
+ # Link the button
482
+ btn.click(
483
+ fn=predict,
484
+ inputs=[file_in],
485
+ outputs=[md_out, prob_out, bee_out, freq_out, imp_out]
486
  )
487
 
 
488
  if __name__ == "__main__":
489
+ # By default, share=False. You can set share=True for external access.
490
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)