hiyata commited on
Commit
9a00943
·
verified ·
1 Parent(s): 552aec4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -29
app.py CHANGED
@@ -148,20 +148,17 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
148
  avg_val = np.mean(shap_means) if n > 0 else 0.0
149
  return (0, n, avg_val)
150
 
151
- # Rolling sum approach
152
- csum = np.cumsum(shap_means) # csum[i] = sum of shap_means[0..i-1]
153
- # function to compute sum in [start, start+window_size)
154
  def window_sum(start):
155
  end = start + window_size
156
  return csum[end] - csum[start]
157
 
158
  best_start = 0
159
- best_avg = None
160
-
161
  # Initialize the best with the first window
162
  best_sum = window_sum(0)
163
  best_avg = best_sum / window_size
164
- best_start = 0
165
 
166
  for start in range(1, n - window_size + 1):
167
  wsum = window_sum(start)
@@ -195,7 +192,10 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
195
  Plots a 1D heatmap of per-base SHAP contributions.
196
  Negative = push toward Non-Human, Positive = push toward Human.
197
  Optionally can show only a subrange (start:end).
198
- We'll adjust layout so that the colorbar is below the x-axis and doesn't overlap.
 
 
 
199
  """
200
  if start is not None and end is not None:
201
  shap_means = shap_means[start:end]
@@ -208,16 +208,16 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
208
  fig, ax = plt.subplots(figsize=(12, 2))
209
  cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
210
 
211
- # Adjust colorbar with some extra margin
212
- # We'll place the colorbar horizontally below
213
- cbar = plt.colorbar(cax, orientation='horizontal', pad=0.25)
214
  cbar.set_label('SHAP Contribution')
215
 
216
  ax.set_yticks([])
217
  ax.set_xlabel('Position in Sequence')
218
  ax.set_title(f"{title}{subtitle}")
219
- # Additional spacing at bottom to avoid overlap
220
- plt.subplots_adjust(bottom=0.3)
 
221
 
222
  return fig
223
 
@@ -280,14 +280,14 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
280
  with open(file_obj, 'r') as f:
281
  text = f.read()
282
  except Exception as e:
283
- return (f"Error reading file: {str(e)}", None, None, None, None, None)
284
  else:
285
- return ("Please provide a FASTA sequence.", None, None, None, None, None)
286
 
287
  # Parse FASTA
288
  sequences = parse_fasta(text)
289
  if not sequences:
290
- return ("No valid FASTA sequences found.", None, None, None, None, None)
291
 
292
  header, seq = sequences[0]
293
 
@@ -298,7 +298,7 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
298
  model.load_state_dict(torch.load('model.pt', map_location=device))
299
  scaler = joblib.load('scaler.pkl')
300
  except Exception as e:
301
- return (f"Error loading model: {str(e)}", None, None, None, None, None)
302
 
303
  # Vectorize + scale
304
  freq_vector = sequence_to_kmer_vector(seq)
@@ -343,20 +343,14 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
343
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
344
  heatmap_img = fig_to_image(heatmap_fig)
345
 
346
- # Return:
347
- # 1) results text
348
- # 2) k-mer bar image
349
- # 3) full-genome heatmap
350
- # 4) "state" with { seq, shap_means, header }, for subregion analysis
351
- # 5) we also return "most pushing" subregion info if we want
352
- # but for simplicity, we can just keep them in the text.
353
- # 6) the sequence header
354
  state_dict = {
355
  "seq": seq,
356
  "shap_means": shap_means
357
  }
358
 
359
- return (results_text, bar_img, heatmap_img, state_dict, header, None)
 
360
 
361
  ###############################################################################
362
  # 8. SUBREGION ANALYSIS (Gradio Step 2)
@@ -481,21 +475,20 @@ with gr.Blocks(css=css) as iface:
481
  kmer_img = gr.Image(label="Top k-mer SHAP")
482
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
483
 
484
- # Hidden states that store data for step 2
485
  seq_state = gr.State()
486
  header_state = gr.State()
487
 
488
- # The "analyze_sequence" function returns 6 values, which we map here:
489
  # 1) results_text
490
  # 2) bar_img
491
  # 3) heatmap_img
492
  # 4) state_dict
493
  # 5) header
494
- # 6) None placeholder
495
  analyze_btn.click(
496
  analyze_sequence,
497
  inputs=[file_input, top_k, text_input, win_size],
498
- outputs=[results_box, kmer_img, genome_img, seq_state, header_state, None]
499
  )
500
 
501
  with gr.Tab("2) Subregion Exploration"):
 
148
  avg_val = np.mean(shap_means) if n > 0 else 0.0
149
  return (0, n, avg_val)
150
 
151
+ # For efficiency, we can do a rolling sum approach
152
+ csum = np.cumsum(shap_means)
153
+ # csum[i] = sum of shap_means[0..i-1]
154
  def window_sum(start):
155
  end = start + window_size
156
  return csum[end] - csum[start]
157
 
158
  best_start = 0
 
 
159
  # Initialize the best with the first window
160
  best_sum = window_sum(0)
161
  best_avg = best_sum / window_size
 
162
 
163
  for start in range(1, n - window_size + 1):
164
  wsum = window_sum(start)
 
192
  Plots a 1D heatmap of per-base SHAP contributions.
193
  Negative = push toward Non-Human, Positive = push toward Human.
194
  Optionally can show only a subrange (start:end).
195
+
196
+ We adjust layout so the colorbar is well below the x-axis:
197
+ - orientation='horizontal', pad=0.35
198
+ - plt.subplots_adjust(bottom=0.4)
199
  """
200
  if start is not None and end is not None:
201
  shap_means = shap_means[start:end]
 
208
  fig, ax = plt.subplots(figsize=(12, 2))
209
  cax = ax.imshow(heatmap_data, aspect='auto', cmap='RdBu_r')
210
 
211
+ # Place colorbar below and add extra margin
212
+ cbar = plt.colorbar(cax, orientation='horizontal', pad=0.35)
 
213
  cbar.set_label('SHAP Contribution')
214
 
215
  ax.set_yticks([])
216
  ax.set_xlabel('Position in Sequence')
217
  ax.set_title(f"{title}{subtitle}")
218
+
219
+ # Extra bottom margin so colorbar won't overlap x-axis labels
220
+ plt.subplots_adjust(bottom=0.4)
221
 
222
  return fig
223
 
 
280
  with open(file_obj, 'r') as f:
281
  text = f.read()
282
  except Exception as e:
283
+ return (f"Error reading file: {str(e)}", None, None, None, None)
284
  else:
285
+ return ("Please provide a FASTA sequence.", None, None, None, None)
286
 
287
  # Parse FASTA
288
  sequences = parse_fasta(text)
289
  if not sequences:
290
+ return ("No valid FASTA sequences found.", None, None, None, None)
291
 
292
  header, seq = sequences[0]
293
 
 
298
  model.load_state_dict(torch.load('model.pt', map_location=device))
299
  scaler = joblib.load('scaler.pkl')
300
  except Exception as e:
301
+ return (f"Error loading model: {str(e)}", None, None, None, None)
302
 
303
  # Vectorize + scale
304
  freq_vector = sequence_to_kmer_vector(seq)
 
343
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
344
  heatmap_img = fig_to_image(heatmap_fig)
345
 
346
+ # Store data for subregion analysis
 
 
 
 
 
 
 
347
  state_dict = {
348
  "seq": seq,
349
  "shap_means": shap_means
350
  }
351
 
352
+ # We now return 5 items (not 6):
353
+ return (results_text, bar_img, heatmap_img, state_dict, header)
354
 
355
  ###############################################################################
356
  # 8. SUBREGION ANALYSIS (Gradio Step 2)
 
475
  kmer_img = gr.Image(label="Top k-mer SHAP")
476
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap")
477
 
478
+ # State for step 2
479
  seq_state = gr.State()
480
  header_state = gr.State()
481
 
482
+ # analyze_sequence(...) now returns 5 items, so we have 5 outputs.
483
  # 1) results_text
484
  # 2) bar_img
485
  # 3) heatmap_img
486
  # 4) state_dict
487
  # 5) header
 
488
  analyze_btn.click(
489
  analyze_sequence,
490
  inputs=[file_input, top_k, text_input, win_size],
491
+ outputs=[results_box, kmer_img, genome_img, seq_state, header_state]
492
  )
493
 
494
  with gr.Tab("2) Subregion Exploration"):