import io
from functools import lru_cache
from itertools import product

import gradio as gr
import joblib
import numpy as np
import torch
import torch.nn as nn
import matplotlib
matplotlib.use("Agg")  # In case we're running in a no-display environment
import matplotlib.pyplot as plt
from matplotlib.patches import Patch
from PIL import Image
import shap


###############################################################################
# Model Definition
###############################################################################

class VirusClassifier(nn.Module):
    """Small MLP mapping a k-mer frequency vector to two output logits.

    The app treats logit index 1 as "human" and index 0 as "non-human".
    The layer layout and the ``network`` attribute name must not change:
    they determine the ``state_dict`` keys that the saved checkpoint uses.
    """

    def __init__(self, input_shape: int):
        super().__init__()
        self.network = nn.Sequential(
            nn.Linear(input_shape, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Dropout(0.3),
            nn.Linear(64, 32),
            nn.GELU(),
            nn.BatchNorm1d(32),
            nn.Dropout(0.3),
            nn.Linear(32, 32),
            nn.GELU(),
            nn.Linear(32, 2),
        )

    def forward(self, x):
        """Return raw (un-softmaxed) logits of shape (batch_size, 2)."""
        return self.network(x)


###############################################################################
# Torch Model Wrapper for SHAP
###############################################################################

class TorchModelWrapper:
    """
    Callable adapter that lets SHAP evaluate a PyTorch model on NumPy input.

    Inputs are converted to float32 tensors on ``device``. The model is run
    without gradient tracking, and the outputs are returned as NumPy arrays.
    """

    def __init__(self, model: nn.Module, device='cpu'):
        self.model = model
        self.device = device

    def __call__(self, x_np: np.ndarray) -> np.ndarray:
        """
        x_np: array-like of shape (batch_size, num_features)
        Returns: numpy array of shape (batch_size, num_outputs)
        """
        # ascontiguousarray: torch.from_numpy refuses negative-stride views
        # (which maskers may hand us) and this also accepts float64 / lists.
        x_contig = np.ascontiguousarray(x_np, dtype=np.float32)
        x_torch = torch.from_numpy(x_contig).to(self.device)
        with torch.no_grad():
            return self.model(x_torch).cpu().numpy()


###############################################################################
# Utility Functions
###############################################################################

def parse_fasta(text):
    """
    Parse FASTA-formatted text into a list of (header, sequence) tuples.

    - Handles multiple records.
    - Sequence lines are upper-cased and stripped of all whitespace.
    - Lines appearing before the first '>' header are ignored.
    - A record with an empty header ('>') is kept with header "".
    """
    sequences = []
    current_header = None
    current_sequence = []
    # A leading byte-order mark would otherwise hide the first '>' header.
    for raw_line in text.lstrip("\ufeff").splitlines():
        line = raw_line.strip()
        if not line:
            continue
        if line.startswith('>'):
            # `is not None` (not truthiness) so an empty header still counts.
            if current_header is not None:
                sequences.append((current_header, ''.join(current_sequence)))
            current_header = line[1:].strip()
            current_sequence = []
        elif current_header is not None:
            current_sequence.append(''.join(line.split()).upper())
    if current_header is not None:
        sequences.append((current_header, ''.join(current_sequence)))
    return sequences


@lru_cache(maxsize=16)
def _kmer_index(k: int) -> dict:
    """Map every ACGT k-mer (in itertools.product order) to its column index."""
    return {''.join(p): i for i, p in enumerate(product("ACGT", repeat=k))}


def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
    """
    Convert a nucleotide sequence to a k-mer frequency vector of length 4^k
    (e.g. 256 for k=4), columns ordered as itertools.product("ACGT", repeat=k).

    Counts are divided by the total number of k-length windows, including
    windows that contain non-ACGT characters (e.g. 'N'), which are otherwise
    not counted. Sequences shorter than k yield an all-zero vector.
    """
    kmer_dict = _kmer_index(k)
    vec = np.zeros(len(kmer_dict), dtype=np.float32)
    total_kmers = len(sequence) - k + 1
    for start in range(total_kmers):
        col = kmer_dict.get(sequence[start:start + k])
        if col is not None:
            vec[col] += 1
    if total_kmers > 0:
        vec = vec / total_kmers  # normalize frequencies
    return vec


###############################################################################
# Visualization Helpers
###############################################################################

def create_freq_sigma_plot(
    single_shap_values: np.ndarray,
    raw_freq_vector: np.ndarray,
    scaled_vector: np.ndarray,
    kmer_list,
    title: str,
    top_k: int = 10,
):
    """
    Bar chart of the ``top_k`` k-mers with the largest absolute SHAP value.

    Each k-mer gets two bars:
    - its frequency in percent (left axis), green when SHAP >= 0 (pushes
      towards "human") and red otherwise;
    - its scaled value (twin right axis), labelled as sigma from the mean.

    single_shap_values: 1-D SHAP values for the "human" class, one per feature
    raw_freq_vector:    1-D raw k-mer frequencies for this sample
    scaled_vector:      1-D scaled (Z-score) values for this sample
    kmer_list:          k-mer label for every feature index
    title:              appended below the chart heading
    top_k:              number of k-mers to show (default 10)

    Returns the Matplotlib Figure; the caller is responsible for closing it.
    """
    abs_vals = np.abs(single_shap_values)
    # Largest |SHAP| first. Slicing the reversed argsort also makes top_k=0
    # yield nothing (whereas [-0:] would select everything).
    top_indices = np.argsort(abs_vals)[::-1][:max(top_k, 0)]

    kmers = [kmer_list[int(i)] for i in top_indices]
    freqs = [float(raw_freq_vector[i]) * 100.0 for i in top_indices]  # percentage
    sigmas = [float(scaled_vector[i]) for i in top_indices]
    colors = ["green" if single_shap_values[i] >= 0 else "red" for i in top_indices]

    x = np.arange(len(kmers))
    width = 0.4

    fig, ax = plt.subplots(figsize=(8, 5))

    # Frequency bars (left axis)
    ax.bar(x - width / 2, freqs, width, color=colors, alpha=0.7)
    ax.set_ylabel("Frequency (%)", color='black')
    max_freq = max(freqs, default=0.0)
    if max_freq > 0:  # set_ylim(0, 0) would be degenerate
        ax.set_ylim(0, max_freq * 1.2)

    # Sigma bars (twin right axis); sigma may be negative, so mark zero.
    ax2 = ax.twinx()
    ax2.bar(x + width / 2, sigmas, width, color="gray", alpha=0.5)
    ax2.axhline(0, color="gray", linewidth=0.8)
    ax2.set_ylabel("Standard Deviations (σ)", color='black')

    ax.set_xticks(x)
    ax.set_xticklabels(kmers, rotation=45, ha='right')
    ax.set_title(f"Top-{len(kmers)} K-mers (Frequency & σ)\n{title}")

    # Proxy handles: a plain bar label would show only the first bar's colour.
    legend_handles = [
        Patch(color="green", alpha=0.7, label="Frequency (%), SHAP ≥ 0 (→ human)"),
        Patch(color="red", alpha=0.7, label="Frequency (%), SHAP < 0 (→ non-human)"),
        Patch(color="gray", alpha=0.5, label="σ from Mean"),
    ]
    ax.legend(handles=legend_handles, loc='upper right')

    plt.tight_layout()
    return fig
###############################################################################
# Main Inference & SHAP Logic
###############################################################################

KMER_SIZE = 4        # k used for feature extraction (4**4 = 256 features)
MAX_BACKGROUND = 50  # max rows of the SHAP background set


def _read_fasta_text(file_obj):
    """
    Turn the Gradio upload payload into text.

    Returns a (text, error_message) pair; exactly one of them is None.
    Strings are passed through unchanged. Bytes are decoded as UTF-8, and
    "utf-8-sig" also strips a byte-order mark that would otherwise hide the
    first '>' header.
    """
    if file_obj is None:
        return None, "No file uploaded!"
    if isinstance(file_obj, str):
        return file_obj, None
    try:
        return file_obj.decode("utf-8-sig"), None
    except (AttributeError, UnicodeDecodeError) as e:
        return None, f"Error reading file: {str(e)}"


def run_classification_and_shap(file_obj):
    """
    Reads one or more FASTA sequences from file_obj (bytes) or text.
    Returns a 5-tuple:
      - table of results (list of dicts), one per sequence
      - shap_values object (shape=(num_samples, num_features, 2); index the
        class on the LAST axis)
      - array of scaled vectors
      - list of k-mers (feature labels)
      - error message or None (all other entries are None on error)
    """
    # 1. Basic read
    text, err = _read_fasta_text(file_obj)
    if err:
        return None, None, None, None, err

    # 2. Parse FASTA
    sequences = parse_fasta(text)
    if not sequences:
        return None, None, None, None, "No valid FASTA sequences found!"

    # 3. Convert each sequence to a k-mer vector
    k = KMER_SIZE
    headers = [hdr for hdr, _ in sequences]
    seqs = [seq for _, seq in sequences]
    all_raw_vectors = np.stack(
        [sequence_to_kmer_vector(seq, k=k) for seq in seqs], axis=0
    )  # shape=(num_seqs, 4**k)

    # 4. Load model & scaler
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        model = VirusClassifier(input_shape=4**k).to(device)
        # weights_only=True: load tensors only, no arbitrary pickled objects
        state_dict = torch.load("model.pt", map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        # NOTE(review): joblib.load unpickles; keep scaler.pkl a trusted local file.
        scaler = joblib.load("scaler.pkl")
    except Exception as e:
        return None, None, None, None, f"Error loading model or scaler: {str(e)}"

    # 5. Scale data
    try:
        scaled_data = scaler.transform(all_raw_vectors)  # shape=(num_seqs, 4**k)
    except ValueError as e:
        return None, None, None, None, f"Error scaling features: {str(e)}"

    # 6. Predictions
    X_tensor = torch.as_tensor(scaled_data, dtype=torch.float32, device=device)
    with torch.no_grad():
        logits = model(X_tensor)  # shape=(num_seqs, 2)
        probs = torch.softmax(logits, dim=1).cpu().numpy()
    preds = np.argmax(probs, axis=1)  # 1 => human, 0 => non-human

    results_table = [
        {
            "header": hdr,
            "sequence": seq[:50] + ("..." if len(seq) > 50 else ""),
            "pred_label": "human" if preds[i] == 1 else "non-human",
            "human_prob": float(probs[i][1]),
            "non_human_prob": float(probs[i][0]),
            "confidence": float(np.max(probs[i])),
        }
        for i, (hdr, seq) in enumerate(zip(headers, seqs))
    ]

    # 7. SHAP explainer
    n_samples, n_features = scaled_data.shape
    if n_samples > 1:
        background_data = scaled_data[:MAX_BACKGROUND]
    else:
        # A background equal to the lone sample makes every SHAP value 0.
        # Use the all-zeros vector instead, which is the training mean for a
        # z-score scaler. NOTE(review): confirm scaler.pkl is a StandardScaler.
        background_data = np.zeros((1, n_features), dtype=scaled_data.dtype)

    wrapped_model = TorchModelWrapper(model, device)
    try:
        explainer = shap.Explainer(wrapped_model, background_data)
        # A plain callable plus a NumPy background selects the Permutation
        # explainer for >10 features. Its default max_evals=500 is below the
        # required 2 * num_features + 1 (513 for 256 features) and would raise.
        shap_values = explainer(scaled_data, max_evals=max(500, 2 * n_features + 1))
    except Exception as e:
        return None, None, None, None, f"Error computing SHAP values: {str(e)}"

    kmer_list = [''.join(p) for p in product("ACGT", repeat=k)]
    return (results_table, shap_values, scaled_data, kmer_list, None)


###############################################################################
# Gradio Callback Helpers
###############################################################################

def _resolve_index(selected_index, n_items):
    """Convert a Gradio Number value to a valid 0-based index (fallback: 0)."""
    try:
        idx = int(selected_index)
    except (TypeError, ValueError, OverflowError):
        return 0
    return idx if 0 <= idx < n_items else 0


def _kmer_feature_names(n_features):
    """ACGT k-mer labels when n_features == 4**k (k >= 1), else 'feat_i'."""
    k = 0
    while 4 ** k < n_features:
        k += 1
    if k > 0 and 4 ** k == n_features:
        return [''.join(p) for p in product("ACGT", repeat=k)]
    return [f"feat_{i}" for i in range(n_features)]


def _escape_md_cell(text):
    """Escape '|' so user text cannot break a markdown table row."""
    return str(text).replace("|", "\\|")


def _figure_to_image(fig):
    """Render *fig* to an in-memory PNG, close it, and return a PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=120)
    plt.close(fig)
    buf.seek(0)
    img = Image.open(buf)
    img.load()  # decode now so the image no longer depends on buf
    return img


###############################################################################
# Gradio Callback Functions
###############################################################################

def main_predict(file_obj):
    """
    Triggered by the 'Run Classification' button in Gradio.
    Returns a markdown table plus states for subsequent plots.
    """
    results, shap_vals, scaled_data, kmer_list, err = run_classification_and_shap(file_obj)
    if err:
        return (err, None, None, None, None)
    if results is None or shap_vals is None:
        return ("An unknown error occurred.", None, None, None, None)

    # Build a summary for all sequences
    md = "# Classification Results\n\n"
    md += "| # | Header | Pred Label | Confidence | Human Prob | Non-human Prob |\n"
    md += "|---|--------|------------|------------|------------|----------------|\n"
    for i, row in enumerate(results):
        md += (
            f"| {i} | {_escape_md_cell(row['header'])} | {row['pred_label']} | "
            f"{row['confidence']:.4f} | {row['human_prob']:.4f} | {row['non_human_prob']:.4f} |\n"
        )
    md += "\nSelect a sequence index below to view SHAP Waterfall & Frequency plots (class=1/human)."

    return (md, shap_vals, scaled_data, kmer_list, results)


def update_waterfall_plot(selected_index, shap_values_obj):
    """
    Waterfall plot of the class=1 ("human") SHAP values for one sample.

    shap_values_obj.values has shape (num_samples, num_features, 2), so
    [index, :, 1] gives one sample's per-feature values for the human class.
    Invalid or out-of-range indices fall back to sample 0.
    """
    if shap_values_obj is None:
        return None

    idx = _resolve_index(selected_index, shap_values_obj.values.shape[0])
    values = shap_values_obj.values[idx, :, 1]  # shape=(num_features,)
    single_expl = shap.Explanation(
        values=values,
        base_values=shap_values_obj.base_values[idx, 1],
        data=shap_values_obj.data[idx],
        feature_names=_kmer_feature_names(values.shape[0]),
    )

    created_fig = plt.figure(figsize=(8, 5))
    shap.plots.waterfall(single_expl, max_display=14, show=False)
    # SHAP draws on the current figure, so render whatever is current now.
    wf_img = _figure_to_image(plt.gcf())
    plt.close(created_fig)  # no-op if it was the figure rendered above
    return wf_img


def update_beeswarm_plot(shap_values_obj):
    """
    Beeswarm plot of the class=1 ("human") SHAP values across all samples.

    Uses shap_values_obj.values[:, :, 1], i.e. shape (num_samples, num_features).
    """
    if shap_values_obj is None:
        return None

    class1_vals = shap_values_obj.values[:, :, 1]
    class1_expl = shap.Explanation(
        values=class1_vals,
        base_values=shap_values_obj.base_values[:, 1],
        data=shap_values_obj.data,
        feature_names=_kmer_feature_names(class1_vals.shape[1]),
    )

    created_fig = plt.figure(figsize=(8, 5))
    shap.plots.beeswarm(class1_expl, show=False)
    bs_img = _figure_to_image(plt.gcf())
    plt.close(created_fig)  # no-op if it was the figure rendered above
    return bs_img


def update_freq_plot(selected_index, shap_values_obj, scaled_data, kmer_list, file_obj):
    """
    Frequency & σ bar chart for the selected sequence's top-10 k-mers
    (ranked by |SHAP| for class=1 / "human").

    The FASTA text is re-parsed from file_obj to recover raw frequencies and
    headers. Invalid or out-of-range indices fall back to sample 0. Returns
    None when the states or the file are unavailable.
    """
    if shap_values_obj is None or scaled_data is None or kmer_list is None:
        return None

    text, err = _read_fasta_text(file_obj)
    if err:
        return None
    sequences = parse_fasta(text)
    n_items = min(len(sequences), shap_values_obj.values.shape[0], len(scaled_data))
    if n_items == 0:
        return None
    idx = _resolve_index(selected_index, n_items)

    header, seq = sequences[idx]
    raw_vec = sequence_to_kmer_vector(seq, k=len(kmer_list[0]))

    freq_sigma_fig = create_freq_sigma_plot(
        shap_values_obj.values[idx, :, 1],
        raw_freq_vector=raw_vec,
        scaled_vector=scaled_data[idx],
        kmer_list=kmer_list,
        title=f"Sample #{idx} — {header}",
    )
    return _figure_to_image(freq_sigma_fig)


###############################################################################
# Gradio Interface
###############################################################################

with gr.Blocks(title="Multi-Sequence Virus Host Classifier with SHAP") as demo:
    # Every plot is a matplotlib PNG, so shap.initjs() (notebook JS, which
    # asserts that IPython is installed) is not needed.
    gr.Markdown(
        """
# **Virus Host Classifier**
Upload a FASTA file with one or more nucleotide sequences. This app will:
1. Predict each sequence's **host** (human vs. non-human).
2. Provide **SHAP** explanations focusing on the 'human' class (index=1).
3. Display:
   - A **waterfall** plot per-sequence (top features).
   - A **beeswarm** plot across all sequences (global summary).
   - A **frequency & σ** bar chart for the top-10 k-mers of any selected sequence.
"""
    )

    with gr.Row():
        file_input = gr.File(label="Upload FASTA", type="binary")
        run_btn = gr.Button("Run Classification")

    # Intermediate results kept per session
    shap_values_state = gr.State()
    scaled_data_state = gr.State()
    kmer_list_state = gr.State()
    results_state = gr.State()
    file_data_state = gr.State()

    with gr.Tabs():
        with gr.Tab("Results Table"):
            md_out = gr.Markdown()

        with gr.Tab("SHAP Waterfall"):
            with gr.Row():
                seq_index_input = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
                update_wf_btn = gr.Button("Update Waterfall")
            wf_plot = gr.Image(label="SHAP Waterfall Plot")

        with gr.Tab("SHAP Beeswarm"):
            bs_plot = gr.Image(label="Global Beeswarm Plot", height=500)

        with gr.Tab("Top-10 Frequency & Sigma"):
            with gr.Row():
                seq_index_input2 = gr.Number(label="Sequence Index (0-based)", value=0, precision=0)
                update_fs_btn = gr.Button("Update Frequency Chart")
            fs_plot = gr.Image(label="Top-10 Frequency & σ Chart")

    # 1) Main classification. The beeswarm is chained with .then() so it runs
    #    after main_predict has written shap_values_state. A separate .click
    #    listener would fire at the same time and read the previous
    #    (initially empty) state.
    classify_event = run_btn.click(
        fn=main_predict,
        inputs=[file_input],
        outputs=[md_out, shap_values_state, scaled_data_state, kmer_list_state, results_state],
    )
    classify_event.then(
        fn=update_beeswarm_plot,
        inputs=[shap_values_state],
        outputs=[bs_plot],
    )
    # Snapshot the uploaded file used for this run (read by the frequency chart).
    run_btn.click(
        fn=lambda x: x,
        inputs=file_input,
        outputs=file_data_state,
    )

    # 2) Update Waterfall
    update_wf_btn.click(
        fn=update_waterfall_plot,
        inputs=[seq_index_input, shap_values_state],
        outputs=[wf_plot],
    )

    # 3) Update Frequency & σ
    update_fs_btn.click(
        fn=update_freq_plot,
        inputs=[seq_index_input2, shap_values_state, scaled_data_state, kmer_list_state, file_data_state],
        outputs=[fs_plot],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860, share=True)