hiyata commited on
Commit
d01c414
·
verified ·
1 Parent(s): 1b5b7bf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -35
app.py CHANGED
@@ -17,8 +17,6 @@ import pandas as pd
17
  import tempfile
18
  import os
19
  from typing import List, Dict, Tuple, Optional, Any
20
- import io
21
- from io import BytesIO
22
  import seaborn as sns
23
 
24
  ###############################################################################
@@ -55,7 +53,8 @@ def parse_fasta(text):
55
  current_sequence = []
56
  for line in text.strip().split('\n'):
57
  line = line.strip()
58
- if not line: continue
 
59
  if line.startswith('>'):
60
  if current_header:
61
  sequences.append((current_header, ''.join(current_sequence)))
@@ -128,7 +127,8 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
128
 
129
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
130
  n = len(shap_means)
131
- if n == 0: return (0, 0, 0.0)
 
132
  if window_size >= n:
133
  return (0, n, float(np.mean(shap_means)))
134
  csum = np.zeros(n + 1, dtype=np.float32)
@@ -140,9 +140,11 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
140
  wsum = csum[start + window_size] - csum[start]
141
  wavg = wsum / window_size
142
  if mode == "max" and wavg > best_avg:
143
- best_avg = wavg; best_start = start
 
144
  elif mode == "min" and wavg < best_avg:
145
- best_avg = wavg; best_start = start
 
146
  return (best_start, best_start + window_size, float(best_avg))
147
 
148
  ###############################################################################
@@ -201,9 +203,9 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
201
  plt.tight_layout()
202
  return fig
203
 
204
- def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
205
  fig, ax = plt.subplots(figsize=(6, 4))
206
- ax.hist(shap_array, bins=30, color='gray', edgecolor='black')
207
  ax.axvline(0, color='red', linestyle='--', label='0.0')
208
  ax.set_xlabel("SHAP Value")
209
  ax.set_ylabel("Count")
@@ -213,7 +215,8 @@ def plot_shap_histogram(shap_array, title="SHAP Distribution in Region"):
213
  return fig
214
 
215
  def compute_gc_content(sequence):
216
- if not sequence: return 0
 
217
  gc_count = sequence.count('G') + sequence.count('C')
218
  return (gc_count / len(sequence)) * 100.0
219
 
@@ -229,23 +232,24 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
229
  with open(file_obj, 'r') as f:
230
  text = f.read()
231
  except Exception as e:
232
- return (f"Error reading file: {str(e)}", None, None, None, None)
233
  else:
234
- return ("Please provide a FASTA sequence.", None, None, None, None)
235
 
236
  sequences = parse_fasta(text)
237
  if not sequences:
238
- return ("No valid FASTA sequences found.", None, None, None, None)
239
  header, seq = sequences[0]
240
 
241
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
242
  try:
243
- state_dict = torch.load('model.pt', map_location=device, weights_only=True)
 
244
  model = VirusClassifier(256).to(device)
245
  model.load_state_dict(state_dict)
246
  scaler = joblib.load('scaler.pkl')
247
  except Exception as e:
248
- return (f"Error loading model/scaler: {str(e)}", None, None, None, None)
249
 
250
  freq_vector = sequence_to_kmer_vector(seq)
251
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
@@ -280,9 +284,11 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
280
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
281
  heatmap_img = fig_to_image(heatmap_fig)
282
 
 
 
283
  state_dict_out = {"seq": seq, "shap_means": shap_means}
284
 
285
- return (results_text, bar_img, heatmap_img, state_dict_out, header)
286
 
287
  ###############################################################################
288
  # 8. SUBREGION ANALYSIS (Gradio Step 2)
@@ -290,7 +296,7 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
290
 
291
  def analyze_subregion(state, header, region_start, region_end):
292
  if not state or "seq" not in state or "shap_means" not in state:
293
- return ("No sequence data found. Please run Step 1 first.", None, None)
294
  seq = state["seq"]
295
  shap_means = state["shap_means"]
296
  region_start = int(region_start)
@@ -298,7 +304,7 @@ def analyze_subregion(state, header, region_start, region_end):
298
  region_start = max(0, min(region_start, len(seq)))
299
  region_end = max(0, min(region_end, len(seq)))
300
  if region_end <= region_start:
301
- return ("Invalid region range. End must be > Start.", None, None)
302
  region_seq = seq[region_start:region_end]
303
  region_shap = shap_means[region_start:region_end]
304
  gc_percent = compute_gc_content(region_seq)
@@ -324,7 +330,9 @@ def analyze_subregion(state, header, region_start, region_end):
324
  heatmap_img = fig_to_image(heatmap_fig)
325
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
326
  hist_img = fig_to_image(hist_fig)
327
- return (region_info, heatmap_img, hist_img)
 
 
328
 
329
  ###############################################################################
330
  # 9. COMPARISON ANALYSIS FUNCTIONS
@@ -476,12 +484,12 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
476
  # Analyze first sequence
477
  res1 = analyze_sequence(file1, top_kmers=10, fasta_text=fasta1, window_size=500)
478
  if isinstance(res1[0], str) and "Error" in res1[0]:
479
- return (f"Error in sequence 1: {res1[0]}", None, None)
480
 
481
  # Analyze second sequence
482
  res2 = analyze_sequence(file2, top_kmers=10, fasta_text=fasta2, window_size=500)
483
  if isinstance(res2[0], str) and "Error" in res2[0]:
484
- return (f"Error in sequence 2: {res2[0]}", None, None)
485
 
486
  # Extract SHAP values and sequence info
487
  shap1 = res1[3]["shap_means"]
@@ -561,11 +569,12 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
561
  )
562
  hist_img = fig_to_image(hist_fig)
563
 
564
- return comparison_text, heatmap_img, hist_img
 
565
 
566
  except Exception as e:
567
  error_msg = f"Error during sequence comparison: {str(e)}"
568
- return error_msg, None, None
569
 
570
  ###############################################################################
571
  # 11. GENE FEATURE ANALYSIS
@@ -753,13 +762,11 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
753
  # Prepare gene name label
754
  label = str(gene.get('gene_name','?'))
755
 
756
- # If getsize() or textsize() is missing, use getmask(...).size as fallback
757
- # But if your Pillow version supports font.getsize, you can do:
758
- # label_width, label_height = font.getsize(label)
759
  label_mask = font.getmask(label)
760
  label_width, label_height = label_mask.size
761
 
762
- # Alternate label positions above/below line
763
  if idx % 2 == 0:
764
  text_y = line_y - track_height - 15
765
  else:
@@ -821,12 +828,10 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
821
 
822
  return img
823
 
824
-
825
-
826
  def analyze_gene_features(sequence_file: str,
827
- features_file: str,
828
- fasta_text: str = "",
829
- features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
830
  """Analyze SHAP values for each gene feature"""
831
  # First analyze whole sequence
832
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
@@ -980,7 +985,7 @@ with gr.Blocks(css=css) as iface:
980
  **Step 3**: Analyze gene features and their contributions.
981
  **Step 4**: Compare sequences and analyze differences.
982
 
983
- **Color Scale**: Negative SHAP = Blue, Zero = White, Positive = Red.
984
  """)
985
 
986
  with gr.Tab("1) Full-Sequence Analysis"):
@@ -998,6 +1003,7 @@ with gr.Blocks(css=css) as iface:
998
  download_results = gr.File(label="Download Results", visible=False, elem_classes="download-button")
999
  seq_state = gr.State()
1000
  header_state = gr.State()
 
1001
  analyze_btn.click(
1002
  analyze_sequence,
1003
  inputs=[file_input, top_k, text_input, win_size],
@@ -1019,6 +1025,7 @@ with gr.Blocks(css=css) as iface:
1019
  subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
1020
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
1021
  download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
 
1022
  region_btn.click(
1023
  analyze_subregion,
1024
  inputs=[seq_state, header_state, region_start, region_end],
@@ -1065,8 +1072,8 @@ with gr.Blocks(css=css) as iface:
1065
  The sequences will be normalized to the same length for comparison.
1066
 
1067
  **Color Scale**:
1068
- - Red: Sequence 2 is more human-like in this region
1069
- - Blue: Sequence 1 is more human-like in this region
1070
  - White: No substantial difference
1071
  """)
1072
  with gr.Row():
@@ -1082,6 +1089,7 @@ with gr.Blocks(css=css) as iface:
1082
  diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
1083
  diff_hist = gr.Image(label="Distribution of SHAP Differences")
1084
  download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
 
1085
  compare_btn.click(
1086
  analyze_sequence_comparison,
1087
  inputs=[file_input1, file_input2, text_input1, text_input2],
@@ -1110,4 +1118,4 @@ with gr.Blocks(css=css) as iface:
1110
  """)
1111
 
1112
  if __name__ == "__main__":
1113
- iface.launch()
 
17
  import tempfile
18
  import os
19
  from typing import List, Dict, Tuple, Optional, Any
 
 
20
  import seaborn as sns
21
 
22
  ###############################################################################
 
53
  current_sequence = []
54
  for line in text.strip().split('\n'):
55
  line = line.strip()
56
+ if not line:
57
+ continue
58
  if line.startswith('>'):
59
  if current_header:
60
  sequences.append((current_header, ''.join(current_sequence)))
 
127
 
128
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
129
  n = len(shap_means)
130
+ if n == 0:
131
+ return (0, 0, 0.0)
132
  if window_size >= n:
133
  return (0, n, float(np.mean(shap_means)))
134
  csum = np.zeros(n + 1, dtype=np.float32)
 
140
  wsum = csum[start + window_size] - csum[start]
141
  wavg = wsum / window_size
142
  if mode == "max" and wavg > best_avg:
143
+ best_avg = wavg
144
+ best_start = start
145
  elif mode == "min" and wavg < best_avg:
146
+ best_avg = wavg
147
+ best_start = start
148
  return (best_start, best_start + window_size, float(best_avg))
149
 
150
  ###############################################################################
 
203
  plt.tight_layout()
204
  return fig
205
 
206
+ def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
207
  fig, ax = plt.subplots(figsize=(6, 4))
208
+ ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
209
  ax.axvline(0, color='red', linestyle='--', label='0.0')
210
  ax.set_xlabel("SHAP Value")
211
  ax.set_ylabel("Count")
 
215
  return fig
216
 
217
  def compute_gc_content(sequence):
218
+ if not sequence:
219
+ return 0
220
  gc_count = sequence.count('G') + sequence.count('C')
221
  return (gc_count / len(sequence)) * 100.0
222
 
 
232
  with open(file_obj, 'r') as f:
233
  text = f.read()
234
  except Exception as e:
235
+ return (f"Error reading file: {str(e)}", None, None, None, None, None)
236
  else:
237
+ return ("Please provide a FASTA sequence.", None, None, None, None, None)
238
 
239
  sequences = parse_fasta(text)
240
  if not sequences:
241
+ return ("No valid FASTA sequences found.", None, None, None, None, None)
242
  header, seq = sequences[0]
243
 
244
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
245
  try:
246
+ # IMPORTANT: adjust how you load your model as needed
247
+ state_dict = torch.load('model.pt', map_location=device)
248
  model = VirusClassifier(256).to(device)
249
  model.load_state_dict(state_dict)
250
  scaler = joblib.load('scaler.pkl')
251
  except Exception as e:
252
+ return (f"Error loading model/scaler: {str(e)}", None, None, None, None, None)
253
 
254
  freq_vector = sequence_to_kmer_vector(seq)
255
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
 
284
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
285
  heatmap_img = fig_to_image(heatmap_fig)
286
 
287
+ # You might want to provide a CSV or other data for the 6th return item
288
+ # Here, we'll simply return None for the file download:
289
  state_dict_out = {"seq": seq, "shap_means": shap_means}
290
 
291
+ return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
292
 
293
  ###############################################################################
294
  # 8. SUBREGION ANALYSIS (Gradio Step 2)
 
296
 
297
  def analyze_subregion(state, header, region_start, region_end):
298
  if not state or "seq" not in state or "shap_means" not in state:
299
+ return ("No sequence data found. Please run Step 1 first.", None, None, None)
300
  seq = state["seq"]
301
  shap_means = state["shap_means"]
302
  region_start = int(region_start)
 
304
  region_start = max(0, min(region_start, len(seq)))
305
  region_end = max(0, min(region_end, len(seq)))
306
  if region_end <= region_start:
307
+ return ("Invalid region range. End must be > Start.", None, None, None)
308
  region_seq = seq[region_start:region_end]
309
  region_shap = shap_means[region_start:region_end]
310
  gc_percent = compute_gc_content(region_seq)
 
330
  heatmap_img = fig_to_image(heatmap_fig)
331
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
332
  hist_img = fig_to_image(hist_fig)
333
+
334
+ # For demonstration, returning None for the file download as well
335
+ return (region_info, heatmap_img, hist_img, None)
336
 
337
  ###############################################################################
338
  # 9. COMPARISON ANALYSIS FUNCTIONS
 
484
  # Analyze first sequence
485
  res1 = analyze_sequence(file1, top_kmers=10, fasta_text=fasta1, window_size=500)
486
  if isinstance(res1[0], str) and "Error" in res1[0]:
487
+ return (f"Error in sequence 1: {res1[0]}", None, None, None)
488
 
489
  # Analyze second sequence
490
  res2 = analyze_sequence(file2, top_kmers=10, fasta_text=fasta2, window_size=500)
491
  if isinstance(res2[0], str) and "Error" in res2[0]:
492
+ return (f"Error in sequence 2: {res2[0]}", None, None, None)
493
 
494
  # Extract SHAP values and sequence info
495
  shap1 = res1[3]["shap_means"]
 
569
  )
570
  hist_img = fig_to_image(hist_fig)
571
 
572
+ # Return 4 outputs (text, image, image, and a file or None for the last)
573
+ return (comparison_text, heatmap_img, hist_img, None)
574
 
575
  except Exception as e:
576
  error_msg = f"Error during sequence comparison: {str(e)}"
577
+ return (error_msg, None, None, None)
578
 
579
  ###############################################################################
580
  # 11. GENE FEATURE ANALYSIS
 
762
  # Prepare gene name label
763
  label = str(gene.get('gene_name','?'))
764
 
765
+ # Fallback for label size
 
 
766
  label_mask = font.getmask(label)
767
  label_width, label_height = label_mask.size
768
 
769
+ # Alternate label positions
770
  if idx % 2 == 0:
771
  text_y = line_y - track_height - 15
772
  else:
 
828
 
829
  return img
830
 
 
 
831
  def analyze_gene_features(sequence_file: str,
832
+ features_file: str,
833
+ fasta_text: str = "",
834
+ features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
835
  """Analyze SHAP values for each gene feature"""
836
  # First analyze whole sequence
837
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
 
985
  **Step 3**: Analyze gene features and their contributions.
986
  **Step 4**: Compare sequences and analyze differences.
987
 
988
+ **Color Scale**: Negative SHAP = Blue, Zero = White, Positive SHAP = Red.
989
  """)
990
 
991
  with gr.Tab("1) Full-Sequence Analysis"):
 
1003
  download_results = gr.File(label="Download Results", visible=False, elem_classes="download-button")
1004
  seq_state = gr.State()
1005
  header_state = gr.State()
1006
+
1007
  analyze_btn.click(
1008
  analyze_sequence,
1009
  inputs=[file_input, top_k, text_input, win_size],
 
1025
  subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
1026
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
1027
  download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
1028
+
1029
  region_btn.click(
1030
  analyze_subregion,
1031
  inputs=[seq_state, header_state, region_start, region_end],
 
1072
  The sequences will be normalized to the same length for comparison.
1073
 
1074
  **Color Scale**:
1075
+ - Red: Sequence 2 more human-like
1076
+ - Blue: Sequence 1 more human-like
1077
  - White: No substantial difference
1078
  """)
1079
  with gr.Row():
 
1089
  diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
1090
  diff_hist = gr.Image(label="Distribution of SHAP Differences")
1091
  download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
1092
+
1093
  compare_btn.click(
1094
  analyze_sequence_comparison,
1095
  inputs=[file_input1, file_input2, text_input1, text_input2],
 
1118
  """)
1119
 
1120
  if __name__ == "__main__":
1121
+ iface.launch()