Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -148,20 +148,17 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
|
|
148 |
avg_val = np.mean(shap_means) if n > 0 else 0.0
|
149 |
return (0, n, avg_val)
|
150 |
|
151 |
-
#
|
152 |
-
csum = np.cumsum(shap_means)
|
153 |
-
#
|
154 |
def window_sum(start):
|
155 |
end = start + window_size
|
156 |
return csum[end] - csum[start]
|
157 |
|
158 |
best_start = 0
|
159 |
-
best_avg = None
|
160 |
-
|
161 |
# Initialize the best with the first window
|
162 |
best_sum = window_sum(0)
|
163 |
best_avg = best_sum / window_size
|
164 |
-
best_start = 0
|
165 |
|
166 |
for start in range(1, n - window_size + 1):
|
167 |
wsum = window_sum(start)
|
@@ -195,7 +192,10 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
|
|
195 |
Plots a 1D heatmap of per-base SHAP contributions.
|
196 |
Negative = push toward Non-Human, Positive = push toward Human.
|
197 |
Optionally can show only a subrange (start:end).
|
198 |
-
|
|
|
|
|
|
|
199 |
"""
|
200 |
if start is not None and end is not None:
|
201 |
shap_means = shap_means[start:end]
|
@@ -208,16 +208,16 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
|
|
208 |
fig, ax = plt.subplots(figsize=(12, 2))
|
209 |
cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
|
210 |
|
211 |
-
#
|
212 |
-
|
213 |
-
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25)
|
214 |
cbar.set_label('SHAP Contribution')
|
215 |
|
216 |
ax.set_yticks([])
|
217 |
ax.set_xlabel('Position in Sequence')
|
218 |
ax.set_title(f"{title}{subtitle}")
|
219 |
-
|
220 |
-
|
|
|
221 |
|
222 |
return fig
|
223 |
|
@@ -280,14 +280,14 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
280 |
with open(file_obj, 'r') as f:
|
281 |
text = f.read()
|
282 |
except Exception as e:
|
283 |
-
return (f"Error reading file: {str(e)}", None, None, None, None
|
284 |
else:
|
285 |
-
return ("Please provide a FASTA sequence.", None, None, None, None
|
286 |
|
287 |
# Parse FASTA
|
288 |
sequences = parse_fasta(text)
|
289 |
if not sequences:
|
290 |
-
return ("No valid FASTA sequences found.", None, None, None, None
|
291 |
|
292 |
header, seq = sequences[0]
|
293 |
|
@@ -298,7 +298,7 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
298 |
model.load_state_dict(torch.load('model.pt', map_location=device))
|
299 |
scaler = joblib.load('scaler.pkl')
|
300 |
except Exception as e:
|
301 |
-
return (f"Error loading model: {str(e)}", None, None, None, None
|
302 |
|
303 |
# Vectorize + scale
|
304 |
freq_vector = sequence_to_kmer_vector(seq)
|
@@ -343,20 +343,14 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
|
|
343 |
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
|
344 |
heatmap_img = fig_to_image(heatmap_fig)
|
345 |
|
346 |
-
#
|
347 |
-
# 1) results text
|
348 |
-
# 2) k-mer bar image
|
349 |
-
# 3) full-genome heatmap
|
350 |
-
# 4) "state" with { seq, shap_means, header }, for subregion analysis
|
351 |
-
# 5) we also return "most pushing" subregion info if we want
|
352 |
-
# but for simplicity, we can just keep them in the text.
|
353 |
-
# 6) the sequence header
|
354 |
state_dict = {
|
355 |
"seq": seq,
|
356 |
"shap_means": shap_means
|
357 |
}
|
358 |
|
359 |
-
|
|
|
360 |
|
361 |
###############################################################################
|
362 |
# 8. SUBREGION ANALYSIS (Gradio Step 2)
|
@@ -481,21 +475,20 @@ with gr.Blocks(css=css) as iface:
|
|
481 |
kmer_img = gr.Image(label="Top k-mer SHAP")
|
482 |
genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
|
483 |
|
484 |
-
#
|
485 |
seq_state = gr.State()
|
486 |
header_state = gr.State()
|
487 |
|
488 |
-
#
|
489 |
# 1) results_text
|
490 |
# 2) bar_img
|
491 |
# 3) heatmap_img
|
492 |
# 4) state_dict
|
493 |
# 5) header
|
494 |
-
# 6) None placeholder
|
495 |
analyze_btn.click(
|
496 |
analyze_sequence,
|
497 |
inputs=[file_input, top_k, text_input, win_size],
|
498 |
-
outputs=[results_box, kmer_img, genome_img, seq_state, header_state
|
499 |
)
|
500 |
|
501 |
with gr.Tab("2) Subregion Exploration"):
|
|
|
148 |
avg_val = np.mean(shap_means) if n > 0 else 0.0
|
149 |
return (0, n, avg_val)
|
150 |
|
151 |
+
# For efficiency, we can do a rolling sum approach
|
152 |
+
csum = np.cumsum(shap_means)
|
153 |
+
# csum[i] = sum of shap_means[0..i-1]
|
154 |
def window_sum(start):
|
155 |
end = start + window_size
|
156 |
return csum[end] - csum[start]
|
157 |
|
158 |
best_start = 0
|
|
|
|
|
159 |
# Initialize the best with the first window
|
160 |
best_sum = window_sum(0)
|
161 |
best_avg = best_sum / window_size
|
|
|
162 |
|
163 |
for start in range(1, n - window_size + 1):
|
164 |
wsum = window_sum(start)
|
|
|
192 |
Plots a 1D heatmap of per-base SHAP contributions.
|
193 |
Negative = push toward Non-Human, Positive = push toward Human.
|
194 |
Optionally can show only a subrange (start:end).
|
195 |
+
|
196 |
+
We adjust layout so the colorbar is well below the x-axis:
|
197 |
+
- orientation='horizontal', pad=0.35
|
198 |
+
- plt.subplots_adjust(bottom=0.4)
|
199 |
"""
|
200 |
if start is not None and end is not None:
|
201 |
shap_means = shap_means[start:end]
|
|
|
208 |
fig, ax = plt.subplots(figsize=(12, 2))
|
209 |
cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
|
210 |
|
211 |
+
# Place colorbar below and add extra margin
|
212 |
+
cbar = plt.colorbar(cax, orientation='horizontal', pad=0.35)
|
|
|
213 |
cbar.set_label('SHAP Contribution')
|
214 |
|
215 |
ax.set_yticks([])
|
216 |
ax.set_xlabel('Position in Sequence')
|
217 |
ax.set_title(f"{title}{subtitle}")
|
218 |
+
|
219 |
+
# Extra bottom margin so colorbar won't overlap x-axis labels
|
220 |
+
plt.subplots_adjust(bottom=0.4)
|
221 |
|
222 |
return fig
|
223 |
|
|
|
280 |
with open(file_obj, 'r') as f:
|
281 |
text = f.read()
|
282 |
except Exception as e:
|
283 |
+
return (f"Error reading file: {str(e)}", None, None, None, None)
|
284 |
else:
|
285 |
+
return ("Please provide a FASTA sequence.", None, None, None, None)
|
286 |
|
287 |
# Parse FASTA
|
288 |
sequences = parse_fasta(text)
|
289 |
if not sequences:
|
290 |
+
return ("No valid FASTA sequences found.", None, None, None, None)
|
291 |
|
292 |
header, seq = sequences[0]
|
293 |
|
|
|
298 |
model.load_state_dict(torch.load('model.pt', map_location=device))
|
299 |
scaler = joblib.load('scaler.pkl')
|
300 |
except Exception as e:
|
301 |
+
return (f"Error loading model: {str(e)}", None, None, None, None)
|
302 |
|
303 |
# Vectorize + scale
|
304 |
freq_vector = sequence_to_kmer_vector(seq)
|
|
|
343 |
heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
|
344 |
heatmap_img = fig_to_image(heatmap_fig)
|
345 |
|
346 |
+
# Store data for subregion analysis
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
state_dict = {
|
348 |
"seq": seq,
|
349 |
"shap_means": shap_means
|
350 |
}
|
351 |
|
352 |
+
# We now return 5 items (not 6):
|
353 |
+
return (results_text, bar_img, heatmap_img, state_dict, header)
|
354 |
|
355 |
###############################################################################
|
356 |
# 8. SUBREGION ANALYSIS (Gradio Step 2)
|
|
|
475 |
kmer_img = gr.Image(label="Top k-mer SHAP")
|
476 |
genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
|
477 |
|
478 |
+
# State for step 2
|
479 |
seq_state = gr.State()
|
480 |
header_state = gr.State()
|
481 |
|
482 |
+
# analyze_sequence(...) now returns 5 items, so we have 5 outputs.
|
483 |
# 1) results_text
|
484 |
# 2) bar_img
|
485 |
# 3) heatmap_img
|
486 |
# 4) state_dict
|
487 |
# 5) header
|
|
|
488 |
analyze_btn.click(
|
489 |
analyze_sequence,
|
490 |
inputs=[file_input, top_k, text_input, win_size],
|
491 |
+
outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
|
492 |
)
|
493 |
|
494 |
with gr.Tab("2) Subregion Exploration"):
|