hiyata commited on
Commit
9308c12
·
verified ·
1 Parent(s): 18efb8a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +285 -331
app.py CHANGED
@@ -67,9 +67,6 @@ def parse_fasta(text):
67
  return sequences
68
 
69
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
70
- """
71
- Convert a sequence into a frequency vector of all possible 4-mer combinations.
72
- """
73
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
74
  kmer_dict = {km: i for i, km in enumerate(kmers)}
75
  vec = np.zeros(len(kmers), dtype=np.float32)
@@ -87,15 +84,11 @@ def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
87
  ###############################################################################
88
 
89
  def calculate_shap_values(model, x_tensor):
90
- """
91
- A simple ablation-based SHAP approximation. Zero out each position
92
- and measure the impact on the 'human' probability.
93
- """
94
  model.eval()
95
  with torch.no_grad():
96
  baseline_output = model(x_tensor)
97
  baseline_probs = torch.softmax(baseline_output, dim=1)
98
- baseline_prob = baseline_probs[0, 1].item() # Probability for 'human'
99
  shap_values = []
100
  x_zeroed = x_tensor.clone()
101
  for i in range(x_tensor.shape[1]):
@@ -113,9 +106,6 @@ def calculate_shap_values(model, x_tensor):
113
  ###############################################################################
114
 
115
  def compute_positionwise_scores(sequence, shap_values, k=4):
116
- """
117
- Distribute each k-mer's SHAP contribution across its k underlying positions.
118
- """
119
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
120
  kmer_dict = {km: i for i, km in enumerate(kmers)}
121
  seq_len = len(sequence)
@@ -136,9 +126,6 @@ def compute_positionwise_scores(sequence, shap_values, k=4):
136
  ###############################################################################
137
 
138
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
139
- """
140
- Use a sliding window to find the subregion with the highest (or lowest) average SHAP.
141
- """
142
  n = len(shap_means)
143
  if n == 0:
144
  return (0, 0, 0.0)
@@ -165,9 +152,6 @@ def find_extreme_subregion(shap_means, window_size=500, mode="max"):
165
  ###############################################################################
166
 
167
  def fig_to_image(fig):
168
- """
169
- Render a Matplotlib figure to a PIL Image.
170
- """
171
  buf = io.BytesIO()
172
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
173
  buf.seek(0)
@@ -176,16 +160,10 @@ def fig_to_image(fig):
176
  return img
177
 
178
  def get_zero_centered_cmap():
179
- """
180
- Create a symmetrical (blue-white-red) colormap around zero.
181
- """
182
  colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
183
  return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
184
 
185
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
186
- """
187
- Plot an inline heatmap for the chosen region (or entire genome if start/end not provided).
188
- """
189
  if start is not None and end is not None:
190
  local_shap = shap_means[start:end]
191
  subtitle = f" (positions {start}-{end})"
@@ -211,9 +189,6 @@ def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, e
211
  return fig
212
 
213
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
214
- """
215
- Show bar chart of top k-mers by absolute SHAP value.
216
- """
217
  plt.rcParams.update({'font.size': 10})
218
  fig = plt.figure(figsize=(10, 5))
219
  indices = np.argsort(np.abs(shap_values))[-top_k:]
@@ -229,9 +204,6 @@ def create_importance_bar_plot(shap_values, kmers, top_k=10):
229
  return fig
230
 
231
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
232
- """
233
- Plot a histogram of SHAP values in some region.
234
- """
235
  fig, ax = plt.subplots(figsize=(6, 4))
236
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
237
  ax.axvline(0, color='red', linestyle='--', label='0.0')
@@ -243,11 +215,8 @@ def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bin
243
  return fig
244
 
245
  def compute_gc_content(sequence):
246
- """
247
- Compute GC content (%) for a given sequence.
248
- """
249
  if not sequence:
250
- return 0.0
251
  gc_count = sequence.count('G') + sequence.count('C')
252
  return (gc_count / len(sequence)) * 100.0
253
 
@@ -256,11 +225,6 @@ def compute_gc_content(sequence):
256
  ###############################################################################
257
 
258
  def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
259
- """
260
- Perform the main classification, SHAP analysis, and extreme subregion detection
261
- for a single sequence.
262
- """
263
- # 1) Read input
264
  if fasta_text.strip():
265
  text = fasta_text.strip()
266
  elif file_obj is not None:
@@ -272,15 +236,14 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
272
  else:
273
  return ("Please provide a FASTA sequence.", None, None, None, None, None)
274
 
275
- # 2) Parse FASTA
276
  sequences = parse_fasta(text)
277
  if not sequences:
278
  return ("No valid FASTA sequences found.", None, None, None, None, None)
279
  header, seq = sequences[0]
280
 
281
- # 3) Load model, scaler, and run inference
282
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
283
  try:
 
284
  state_dict = torch.load('model.pt', map_location=device)
285
  model = VirusClassifier(256).to(device)
286
  model.load_state_dict(state_dict)
@@ -297,12 +260,10 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
297
  classification = "Human" if prob_human > 0.5 else "Non-human"
298
  confidence = max(prob_human, prob_nonhuman)
299
 
300
- # 4) Per-base SHAP & subregion detection
301
  shap_means = compute_positionwise_scores(seq, shap_values, k=4)
302
  max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
303
  min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")
304
 
305
- # 5) Prepare result text
306
  results_text = (
307
  f"Sequence: {header}\n"
308
  f"Length: {len(seq):,} bases\n"
@@ -316,7 +277,6 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
316
  f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
317
  )
318
 
319
- # 6) Create bar & heatmap figures
320
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
321
  bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
322
  bar_img = fig_to_image(bar_fig)
@@ -324,10 +284,10 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
324
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
325
  heatmap_img = fig_to_image(heatmap_fig)
326
 
327
- # 7) Build the "state" dictionary so we can do subregion analysis
 
328
  state_dict_out = {"seq": seq, "shap_means": shap_means}
329
 
330
- # Return 6 items to match your Gradio output
331
  return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
332
 
333
  ###############################################################################
@@ -335,9 +295,6 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
335
  ###############################################################################
336
 
337
  def analyze_subregion(state, header, region_start, region_end):
338
- """
339
- Examine a subregion’s SHAP distribution, GC content, etc.
340
- """
341
  if not state or "seq" not in state or "shap_means" not in state:
342
  return ("No sequence data found. Please run Step 1 first.", None, None, None)
343
  seq = state["seq"]
@@ -348,22 +305,18 @@ def analyze_subregion(state, header, region_start, region_end):
348
  region_end = max(0, min(region_end, len(seq)))
349
  if region_end <= region_start:
350
  return ("Invalid region range. End must be > Start.", None, None, None)
351
-
352
  region_seq = seq[region_start:region_end]
353
  region_shap = shap_means[region_start:region_end]
354
-
355
  gc_percent = compute_gc_content(region_seq)
356
  avg_shap = float(np.mean(region_shap))
357
  positive_fraction = np.mean(region_shap > 0)
358
  negative_fraction = np.mean(region_shap < 0)
359
-
360
  if avg_shap > 0.05:
361
  region_classification = "Likely pushing toward human"
362
  elif avg_shap < -0.05:
363
  region_classification = "Likely pushing toward non-human"
364
  else:
365
  region_classification = "Near neutral (no strong push)"
366
-
367
  region_info = (
368
  f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
369
  f"Region length: {len(region_seq)} bases\n"
@@ -373,29 +326,30 @@ def analyze_subregion(state, header, region_start, region_end):
373
  f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
374
  f"Subregion interpretation: {region_classification}\n"
375
  )
376
-
377
  heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
378
  heatmap_img = fig_to_image(heatmap_fig)
379
-
380
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
381
  hist_img = fig_to_image(hist_fig)
382
-
383
- # Return 4 items to match your Gradio output
384
  return (region_info, heatmap_img, hist_img, None)
385
 
386
  ###############################################################################
387
- # 9. COMPARISON ANALYSIS FUNCTIONS (Step 4)
388
  ###############################################################################
389
 
 
 
 
 
 
390
  def compute_shap_difference(shap1_norm, shap2_norm):
391
- """
392
- Compute the SHAP difference (Seq2 - Seq1).
393
- """
394
  return shap2_norm - shap1_norm
395
 
396
  def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
397
  """
398
- Plot a 1D heatmap of differences using relative positions 0-100%.
399
  """
400
  heatmap_data = shap_diff.reshape(1, -1)
401
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
@@ -424,7 +378,7 @@ def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
424
 
425
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
426
  """
427
- Plot a histogram of SHAP values with optional # of bins.
428
  """
429
  fig, ax = plt.subplots(figsize=(6, 4))
430
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
@@ -438,16 +392,18 @@ def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
438
 
439
  def calculate_adaptive_parameters(len1, len2):
440
  """
441
- Choose smoothing & interpolation parameters automatically based on length difference.
 
442
  """
443
  length_diff = abs(len1 - len2)
444
  max_length = max(len1, len2)
445
  min_length = min(len1, len2)
446
  length_ratio = min_length / max_length
447
 
448
- # Base number of points
449
  base_points = min(2000, max(500, max_length // 100))
450
 
 
451
  if length_diff < 500:
452
  resolution_factor = 2.0
453
  num_points = min(3000, base_points * 2)
@@ -465,22 +421,29 @@ def calculate_adaptive_parameters(len1, len2):
465
  num_points = max(500, base_points // 2)
466
  smooth_window = max(100, length_diff // 500)
467
 
 
468
  smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
 
469
  return int(num_points), int(smooth_window), resolution_factor
470
 
471
  def sliding_window_smooth(values, window_size=50):
472
  """
473
- A custom smoothing approach, including exponential decay at edges.
474
  """
475
  if window_size < 3:
476
  return values
 
 
477
  window = np.ones(window_size)
478
  decay = np.exp(-np.linspace(0, 3, window_size // 2))
479
  window[:window_size // 2] = decay
480
  window[-(window_size // 2):] = decay[::-1]
481
  window = window / window.sum()
482
 
 
483
  smoothed = np.convolve(values, window, mode='valid')
 
 
484
  pad_size = len(values) - len(smoothed)
485
  pad_left = pad_size // 2
486
  pad_right = pad_size - pad_left
@@ -494,13 +457,16 @@ def sliding_window_smooth(values, window_size=50):
494
 
495
  def normalize_shap_lengths(shap1, shap2):
496
  """
497
- Smooth, interpolate, and return arrays of the same length for direct comparison.
498
  """
 
499
  num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
500
 
 
501
  shap1_smooth = sliding_window_smooth(shap1, smooth_window)
502
  shap2_smooth = sliding_window_smooth(shap2, smooth_window)
503
 
 
504
  x1 = np.linspace(0, 1, len(shap1_smooth))
505
  x2 = np.linspace(0, 1, len(shap2_smooth))
506
  x_norm = np.linspace(0, 1, num_points)
@@ -512,8 +478,7 @@ def normalize_shap_lengths(shap1, shap2):
512
 
513
  def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
514
  """
515
- Compare two sequences using the previously defined analysis pipeline
516
- and produce difference visualizations & stats.
517
  """
518
  try:
519
  # Analyze first sequence
@@ -526,23 +491,26 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
526
  if isinstance(res2[0], str) and "Error" in res2[0]:
527
  return (f"Error in sequence 2: {res2[0]}", None, None, None)
528
 
 
529
  shap1 = res1[3]["shap_means"]
530
  shap2 = res2[3]["shap_means"]
531
 
 
532
  len1, len2 = len(shap1), len(shap2)
533
  length_diff = abs(len1 - len2)
534
  length_ratio = min(len1, len2) / max(len1, len2)
535
-
536
- # Normalize both to the same length
537
  shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
538
  shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
539
 
540
- # Compute stats
541
  base_threshold = 0.05
542
  adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
543
  if length_diff > 50000:
544
  adaptive_threshold *= 1.5
545
 
 
546
  avg_diff = np.mean(shap_diff)
547
  std_diff = np.std(shap_diff)
548
  max_diff = np.max(shap_diff)
@@ -550,7 +518,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
550
  substantial_diffs = np.abs(shap_diff) > adaptive_threshold
551
  frac_different = np.mean(substantial_diffs)
552
 
553
- # Extract classification from text
554
  try:
555
  classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
556
  classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
@@ -558,6 +526,7 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
558
  classification1 = "Unknown"
559
  classification2 = "Unknown"
560
 
 
561
  comparison_text = (
562
  "Sequence Comparison Results:\n"
563
  f"Sequence 1: {res1[4]}\n"
@@ -584,12 +553,14 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
584
  "- White regions: Similar between sequences"
585
  )
586
 
 
587
  heatmap_fig = plot_comparative_heatmap(
588
  shap_diff,
589
  title=f"SHAP Difference Heatmap (window: {smooth_window})"
590
  )
591
  heatmap_img = fig_to_image(heatmap_fig)
592
 
 
593
  num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
594
  hist_fig = plot_shap_histogram(
595
  shap_diff,
@@ -598,62 +569,31 @@ def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
598
  )
599
  hist_img = fig_to_image(hist_fig)
600
 
 
601
  return (comparison_text, heatmap_img, hist_img, None)
602
 
603
  except Exception as e:
604
  error_msg = f"Error during sequence comparison: {str(e)}"
605
  return (error_msg, None, None, None)
606
 
607
- ###############################################################################
608
- # 10. ADDITIONAL / ADVANCED VISUALIZATIONS & STATISTICS
609
- ###############################################################################
610
-
611
- def n50_length(sequence):
612
- """
613
- Calculate the N50 for a single continuous sequence (for demonstration).
614
- For a single sequence, N50 is typically the length if it's just one piece,
615
- but let's do a simplistic example.
616
- """
617
- # If you had contigs, you'd do a sorted list, cumulative sums, etc.
618
- # We'll do a trivial approach here:
619
- return len(sequence) # Because we have only one contiguous region
620
-
621
- def sequence_complexity(sequence):
622
- """
623
- Compute a simple measure of 'sequence complexity'.
624
- Here, we define complexity as the Shannon entropy over the nucleotides.
625
- """
626
- from math import log2
627
- length = len(sequence)
628
- if length == 0:
629
- return 0.0
630
- freq = {}
631
- for base in sequence:
632
- freq[base] = freq.get(base, 0) + 1
633
- complexity = 0.0
634
- for base, count in freq.items():
635
- p = count / length
636
- complexity -= p * log2(p)
637
- return complexity
638
-
639
- def advanced_gene_statistics(gene_shap: np.ndarray, gene_seq: str) -> Dict[str, float]:
640
- """
641
- Additional stats: N50, complexity, etc.
642
- """
643
- stats = {}
644
- stats['n50'] = len(gene_seq) # trivial for a single gene region
645
- stats['entropy'] = sequence_complexity(gene_seq)
646
- stats['avg_shap'] = float(np.mean(gene_shap))
647
- stats['max_shap'] = float(np.max(gene_shap)) if len(gene_shap) else 0.0
648
- stats['min_shap'] = float(np.min(gene_shap)) if len(gene_shap) else 0.0
649
- return stats
650
-
651
  ###############################################################################
652
  # 11. GENE FEATURE ANALYSIS
653
  ###############################################################################
654
 
 
 
 
 
 
 
 
 
 
 
 
 
655
  def parse_gene_features(text: str) -> List[Dict[str, Any]]:
656
- """Parse gene features from text file in a FASTA-like format."""
657
  genes = []
658
  current_header = None
659
  current_sequence = []
@@ -662,6 +602,7 @@ def parse_gene_features(text: str) -> List[Dict[str, Any]]:
662
  line = line.strip()
663
  if not line:
664
  continue
 
665
  if line.startswith('>'):
666
  if current_header:
667
  genes.append({
@@ -673,29 +614,36 @@ def parse_gene_features(text: str) -> List[Dict[str, Any]]:
673
  current_sequence = []
674
  else:
675
  current_sequence.append(line.upper())
 
676
  if current_header:
677
  genes.append({
678
  'header': current_header,
679
  'sequence': ''.join(current_sequence),
680
  'metadata': parse_gene_metadata(current_header)
681
  })
 
682
  return genes
683
 
684
  def parse_gene_metadata(header: str) -> Dict[str, str]:
685
- """Extract metadata from gene header line."""
686
  metadata = {}
687
  parts = header.split()
 
688
  for part in parts:
689
  if '[' in part and ']' in part:
690
  key_value = part[1:-1].split('=', 1)
691
  if len(key_value) == 2:
692
  metadata[key_value[0]] = key_value[1]
 
693
  return metadata
694
 
695
  def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
696
- """Parse gene location string, handling forward and complement strands."""
697
  try:
 
698
  clean_loc = location_str.replace('complement(', '').replace(')', '')
 
 
699
  if '..' in clean_loc:
700
  start, end = map(int, clean_loc.split('..'))
701
  return start, end
@@ -706,41 +654,48 @@ def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
706
  return None, None
707
 
708
  def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
709
- """Basic statistical measures for gene SHAP values."""
710
  return {
711
- 'avg_shap': float(np.mean(gene_shap)) if len(gene_shap) else 0.0,
712
- 'median_shap': float(np.median(gene_shap)) if len(gene_shap) else 0.0,
713
- 'std_shap': float(np.std(gene_shap)) if len(gene_shap) else 0.0,
714
- 'max_shap': float(np.max(gene_shap)) if len(gene_shap) else 0.0,
715
- 'min_shap': float(np.min(gene_shap)) if len(gene_shap) else 0.0,
716
- 'pos_fraction': float(np.mean(gene_shap > 0)) if len(gene_shap) else 0.0
717
  }
718
 
719
  def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
720
  """
721
- A quick PIL-based diagram to show genes along the genome.
722
- Color intensity = magnitude of SHAP. Red/Blue = sign of SHAP.
723
  """
 
 
 
724
  if not gene_results or genome_length <= 0:
725
  img = Image.new('RGB', (800, 100), color='white')
726
  draw = ImageDraw.Draw(img)
727
  draw.text((10, 40), "Error: Invalid input data", fill='black')
728
  return img
729
-
 
730
  for gene in gene_results:
731
  gene['start'] = max(0, int(gene['start']))
732
  gene['end'] = min(genome_length, int(gene['end']))
733
  if gene['start'] >= gene['end']:
734
- print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}")
735
 
 
736
  width = 1500
737
  height = 600
738
  margin = 50
739
  track_height = 40
740
 
 
741
  img = Image.new('RGB', (width, height), 'white')
742
  draw = ImageDraw.Draw(img)
743
 
 
744
  try:
745
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
746
  title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
@@ -748,16 +703,24 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
748
  font = ImageFont.load_default()
749
  title_font = ImageFont.load_default()
750
 
751
- draw.text((margin, margin // 2), "Genome SHAP Analysis (Simple)", fill='black', font=title_font or font)
 
752
 
 
753
  line_y = height // 2
754
  draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))], fill='black', width=2)
755
 
 
756
  scale = float(width - 2 * margin) / float(genome_length)
757
 
758
- # Scale markers
759
  num_ticks = 10
760
- step = max(1, genome_length // num_ticks)
 
 
 
 
 
761
  for i in range(0, genome_length + 1, step):
762
  x_coord = margin + i * scale
763
  draw.line([
@@ -766,33 +729,50 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
766
  ], fill='black', width=1)
767
  draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
768
 
 
769
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
 
 
770
  for idx, gene in enumerate(sorted_genes):
 
771
  start_x = margin + int(gene['start'] * scale)
772
  end_x = margin + int(gene['end'] * scale)
 
 
773
  avg_shap = gene['avg_shap']
 
 
 
774
  intensity = int(abs(avg_shap) * 500)
775
- intensity = max(50, min(255, intensity))
776
 
777
  if avg_shap > 0:
778
- color = (255, 255 - intensity, 255 - intensity) # Redish
 
779
  else:
780
- color = (255 - intensity, 255 - intensity, 255) # Blueish
 
781
 
 
782
  draw.rectangle([
783
  (int(start_x), int(line_y - track_height // 2)),
784
  (int(end_x), int(line_y + track_height // 2))
785
  ], fill=color, outline='black')
786
 
 
787
  label = str(gene.get('gene_name','?'))
 
 
788
  label_mask = font.getmask(label)
789
  label_width, label_height = label_mask.size
790
 
 
791
  if idx % 2 == 0:
792
  text_y = line_y - track_height - 15
793
  else:
794
  text_y = line_y + track_height + 5
795
 
 
796
  gene_width = end_x - start_x
797
  if gene_width > label_width:
798
  text_x = start_x + (gene_width - label_width) // 2
@@ -804,113 +784,64 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
804
  rotated_img = txt_img.rotate(90, expand=True)
805
  img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
806
 
807
- return img
808
-
809
- def create_advanced_genome_diagram(gene_results: List[Dict[str, Any]],
810
- genome_length: int,
811
- shap_means: np.ndarray,
812
- diagram_title: str = "Advanced Genome Diagram") -> Image.Image:
813
- """
814
- An advanced genome diagram using Biopython's GenomeDiagram.
815
- We'll create tracks for genes and a 'SHAP line plot' track.
816
- """
817
- if not gene_results or genome_length <= 0 or len(shap_means) == 0:
818
- # Fallback if data is invalid
819
- img = Image.new('RGB', (800, 100), color='white')
820
- d = ImageDraw.Draw(img)
821
- d.text((10, 40), "Error: Not enough data for advanced diagram", fill='black')
822
- return img
823
-
824
- diagram = GenomeDiagram.Diagram(diagram_title)
825
- gene_track = diagram.new_track(1, name="Genes", greytrack=False, height=0.5)
826
- gene_set = gene_track.new_set()
827
-
828
- # Add each gene as a feature
829
- for gene in gene_results:
830
- start = max(0, int(gene['start']))
831
- end = min(genome_length, int(gene['end']))
832
- avg_shap = gene['avg_shap']
833
- # Color scale: negative = blue, positive = red
834
- intensity = abs(avg_shap) * 500
835
- intensity = max(50, min(255, intensity))
836
- if avg_shap >= 0:
837
- color_hex = colors.Color(1.0, 1.0 - intensity/255.0, 1.0 - intensity/255.0)
838
- else:
839
- color_hex = colors.Color(1.0 - intensity/255.0, 1.0 - intensity/255.0, 1.0)
840
-
841
- feature = SeqFeature(FeatureLocation(start, end), strand=1)
842
- gene_set.add_feature(
843
- feature,
844
- color=color_hex,
845
- label=True,
846
- name=str(gene.get('gene_name','?')),
847
- label_size=8,
848
- label_color=colors.black
849
- )
850
-
851
- # Add a track for the SHAP line
852
- shap_track = diagram.new_track(2, name="SHAP Score", greytrack=False, height=0.3)
853
- shap_set = shap_track.new_set("graph")
854
- # We'll plot the entire shap_means array.
855
- # X coords = [0..genome_length], Y coords = shap_means
856
- # We'll keep negative values below baseline, positive above.
857
-
858
- # Normalizing for visualization
859
- max_abs = max(abs(shap_means.min()), abs(shap_means.max()))
860
- if max_abs == 0:
861
- scaled_shap = [0]*len(shap_means)
862
- else:
863
- scaled_shap = (shap_means / max_abs * 50).tolist() # scale to +/- 50
864
 
865
- shap_set.add_graph(
866
- data=scaled_shap,
867
- name="shap_line",
868
- style="line",
869
- color=colors.darkgreen,
870
- altcolor=colors.red,
871
- linewidth=1
872
- )
873
-
874
- # Draw to a temporary file
875
- with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmpf:
876
- diagram.draw(format="linear", pagesize='A3', fragments=1, start=0, end=genome_length)
877
- diagram.write(tmpf.name, "PDF")
878
-
879
- # Convert PDF to a PIL image (requires poppler or similar).
880
- # If you do not have poppler, you can skip PDF -> image or use Cairo.
881
- try:
882
- import pdf2image
883
- pages = pdf2image.convert_from_path(tmpf.name, dpi=100)
884
- img = pages[0] if pages else Image.new('RGB', (800, 100), color='white')
885
- except ImportError:
886
- img = Image.new('RGB', (800, 100), color='white')
887
- d = ImageDraw.Draw(img)
888
- d.text((10, 40), "pdf2image not installed, can't show advanced diagram as image.", fill='black')
889
-
890
- # Cleanup
891
- os.remove(tmpf.name)
892
  return img
893
 
894
  def analyze_gene_features(sequence_file: str,
895
  features_file: str,
896
  fasta_text: str = "",
897
- features_text: str = "",
898
- diagram_mode: str = "advanced"
899
- ) -> Tuple[str, Optional[str], Optional[Image.Image]]:
900
- """
901
- Analyze each gene in the features file, compute gene-level SHAP stats,
902
- produce tabular output, and create an optional genome diagram.
903
- """
904
- # 1) Analyze the entire sequence with the top-level function
905
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
906
  if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
907
  return f"Error in sequence analysis: {sequence_results[0]}", None, None
908
 
909
- seq = sequence_results[3]["seq"]
910
  shap_means = sequence_results[3]["shap_means"]
911
- genome_length = len(seq)
912
-
913
- # 2) Read gene features
914
  try:
915
  if features_text.strip():
916
  genes = parse_gene_features(features_text)
@@ -919,100 +850,98 @@ def analyze_gene_features(sequence_file: str,
919
  genes = parse_gene_features(f.read())
920
  except Exception as e:
921
  return f"Error reading features file: {str(e)}", None, None
922
-
 
923
  gene_results = []
924
  for gene in genes:
925
- location = gene['metadata'].get('location', '')
926
- if not location:
927
- continue
928
- start, end = parse_location(location)
929
- if start is None or end is None or start >= end or end > genome_length:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
930
  continue
931
- gene_shap = shap_means[start:end]
932
- basic_stats = compute_gene_statistics(gene_shap)
933
- # Additional stats
934
- gene_seq = seq[start:end]
935
- adv_stats = advanced_gene_statistics(gene_shap, gene_seq)
936
-
937
- # Merge basic + advanced stats
938
- all_stats = {**basic_stats, **adv_stats}
939
-
940
- classification = 'Human' if basic_stats['avg_shap'] > 0 else 'Non-human'
941
- locus_tag = gene['metadata'].get('locus_tag', '')
942
- gene_name = gene['metadata'].get('gene', 'Unknown')
943
-
944
- gene_dict = {
945
- 'gene_name': gene_name,
946
- 'location': location,
947
- 'start': start,
948
- 'end': end,
949
- 'locus_tag': locus_tag,
950
- 'avg_shap': all_stats['avg_shap'],
951
- 'median_shap': basic_stats['median_shap'],
952
- 'std_shap': basic_stats['std_shap'],
953
- 'max_shap': basic_stats['max_shap'],
954
- 'min_shap': basic_stats['min_shap'],
955
- 'pos_fraction': basic_stats['pos_fraction'],
956
- 'n50': all_stats['n50'],
957
- 'entropy': all_stats['entropy'],
958
- 'classification': classification,
959
- 'confidence': abs(all_stats['avg_shap'])
960
- }
961
- gene_results.append(gene_dict)
962
-
963
  if not gene_results:
964
  return "No valid genes could be processed", None, None
965
-
966
- # 3) Summaries
967
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
 
 
968
  results_text = "Gene Analysis Results:\n\n"
969
  results_text += f"Total genes analyzed: {len(gene_results)}\n"
970
- num_human = sum(1 for g in gene_results if g['classification'] == 'Human')
971
- results_text += f"Human-like genes: {num_human}\n"
972
- results_text += f"Non-human-like genes: {len(gene_results) - num_human}\n\n"
973
 
974
- results_text += "Top 10 most distinctive genes (by avg SHAP magnitude):\n"
975
  for gene in sorted_genes[:10]:
976
  results_text += (
977
  f"Gene: {gene['gene_name']}\n"
978
  f"Location: {gene['location']}\n"
979
  f"Classification: {gene['classification']} "
980
  f"(confidence: {gene['confidence']:.4f})\n"
981
- f"Average SHAP: {gene['avg_shap']:.4f}\n"
982
- f"N50: {gene['n50']}, Entropy: {gene['entropy']:.3f}\n\n"
983
  )
984
-
985
- # 4) Make CSV
986
- csv_content = "gene_name,location,start,end,locus_tag,avg_shap,median_shap,std_shap,"
987
- csv_content += "max_shap,min_shap,pos_fraction,n50,entropy,classification,confidence\n"
988
- for g in gene_results:
 
989
  csv_content += (
990
- f"{g['gene_name']},{g['location']},{g['start']},{g['end']},{g['locus_tag']},"
991
- f"{g['avg_shap']:.4f},{g['median_shap']:.4f},{g['std_shap']:.4f},"
992
- f"{g['max_shap']:.4f},{g['min_shap']:.4f},{g['pos_fraction']:.4f},"
993
- f"{g['n50']},{g['entropy']:.4f},{g['classification']},{g['confidence']:.4f}\n"
994
  )
 
 
995
  try:
996
  temp_dir = tempfile.gettempdir()
997
  temp_path = os.path.join(temp_dir, f"gene_analysis_{os.urandom(4).hex()}.csv")
 
998
  with open(temp_path, 'w') as f:
999
  f.write(csv_content)
1000
  except Exception as e:
1001
  print(f"Error saving CSV: {str(e)}")
1002
  temp_path = None
1003
-
1004
- # 5) Create diagram
1005
  try:
1006
- if diagram_mode == "advanced":
1007
- diagram_img = create_advanced_genome_diagram(gene_results, genome_length, shap_means)
1008
- else:
1009
- diagram_img = create_simple_genome_diagram(gene_results, genome_length)
1010
  except Exception as e:
1011
  print(f"Error creating visualization: {str(e)}")
 
1012
  diagram_img = Image.new('RGB', (800, 100), color='white')
1013
  draw = ImageDraw.Draw(diagram_img)
1014
  draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')
1015
-
1016
  return results_text, temp_path, diagram_img
1017
 
1018
  ###############################################################################
@@ -1020,14 +949,13 @@ def analyze_gene_features(sequence_file: str,
1020
  ###############################################################################
1021
 
1022
  def prepare_csv_download(data, filename="analysis_results.csv"):
1023
- """
1024
- Convert data to CSV for Gradio download button.
1025
- """
1026
  if isinstance(data, str):
1027
  return data.encode(), filename
1028
  elif isinstance(data, (list, dict)):
1029
  import csv
1030
  from io import StringIO
 
1031
  output = StringIO()
1032
  writer = csv.DictWriter(output, fieldnames=data[0].keys())
1033
  writer.writeheader()
@@ -1051,22 +979,22 @@ css = """
1051
 
1052
  with gr.Blocks(css=css) as iface:
1053
  gr.Markdown("""
1054
- # Virus Host Classifier + Extended Genome Visualization
1055
- **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme subregions.
1056
- **Step 2**: Explore subregions (local SHAP, GC content, histogram).
1057
- **Step 3**: Analyze gene features (per-gene SHAP, advanced stats, improved diagrams).
1058
- **Step 4**: Compare sequences for SHAP differences.
1059
-
1060
- **Color Scale**: Negative SHAP = Blue, 0 = White, Positive = Red.
1061
  """)
1062
 
1063
  with gr.Tab("1) Full-Sequence Analysis"):
1064
  with gr.Row():
1065
  with gr.Column(scale=1):
1066
  file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1067
- text_input = gr.Textbox(label="Or paste FASTA", placeholder=">name\nACGT...", lines=5)
1068
  top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
1069
- win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Subregion Window Size")
1070
  analyze_btn = gr.Button("Analyze Sequence", variant="primary")
1071
  with gr.Column(scale=2):
1072
  results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
@@ -1085,7 +1013,8 @@ with gr.Blocks(css=css) as iface:
1085
  with gr.Tab("2) Subregion Exploration"):
1086
  gr.Markdown("""
1087
  **Subregion Analysis**
1088
- View SHAP signals, GC content, etc. for a specific region.
 
1089
  """)
1090
  with gr.Row():
1091
  region_start = gr.Number(label="Region Start", value=0)
@@ -1095,7 +1024,7 @@ with gr.Blocks(css=css) as iface:
1095
  with gr.Row():
1096
  subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
1097
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
1098
- download_subregion = gr.File(label="Download Subregion", visible=False, elem_classes="download-button")
1099
 
1100
  region_btn.click(
1101
  analyze_subregion,
@@ -1106,48 +1035,60 @@ with gr.Blocks(css=css) as iface:
1106
  with gr.Tab("3) Gene Features Analysis"):
1107
  gr.Markdown("""
1108
  **Analyze Gene Features**
1109
- - Upload a FASTA file and a gene features file.
1110
- - See per-gene SHAP, classification, N50, entropy, etc.
1111
- - Choose a diagram mode (simple or advanced).
 
 
 
 
 
 
 
1112
  """)
1113
  with gr.Row():
1114
  with gr.Column(scale=1):
1115
- gene_fasta_file = gr.File(label="FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1116
- gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence", lines=5)
1117
  with gr.Column(scale=1):
1118
- features_file = gr.File(label="Gene features file", file_types=[".txt"], type="filepath")
1119
- features_text = gr.Textbox(label="Or paste gene features", lines=5)
1120
- diagram_mode = gr.Radio(choices=["simple", "advanced"], value="advanced", label="Diagram Mode")
1121
  analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
1122
  gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
1123
- gene_diagram = gr.Image(label="Genome Diagram")
1124
  download_gene_results = gr.File(label="Download Gene Analysis (CSV)", visible=True)
1125
 
1126
  analyze_genes_btn.click(
1127
  analyze_gene_features,
1128
- inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text, diagram_mode],
1129
  outputs=[gene_results, download_gene_results, gene_diagram]
1130
  )
1131
 
1132
  with gr.Tab("4) Comparative Analysis"):
1133
  gr.Markdown("""
1134
  **Compare Two Sequences**
1135
- - Upload or paste two FASTA sequences.
1136
- - We'll compare SHAP patterns (normalized for different lengths).
 
 
 
 
 
1137
  """)
1138
  with gr.Row():
1139
  with gr.Column(scale=1):
1140
- file_input1 = gr.File(label="1st FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1141
- text_input1 = gr.Textbox(label="Or paste 1st FASTA", lines=5)
1142
  with gr.Column(scale=1):
1143
- file_input2 = gr.File(label="2nd FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1144
- text_input2 = gr.Textbox(label="Or paste 2nd FASTA", lines=5)
1145
  compare_btn = gr.Button("Compare Sequences", variant="primary")
1146
  comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
1147
  with gr.Row():
1148
  diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
1149
  diff_hist = gr.Image(label="Distribution of SHAP Differences")
1150
- download_comparison = gr.File(label="Download Comparison", visible=False, elem_classes="download-button")
1151
 
1152
  compare_btn.click(
1153
  analyze_sequence_comparison,
@@ -1156,12 +1097,25 @@ with gr.Blocks(css=css) as iface:
1156
  )
1157
 
1158
  gr.Markdown("""
1159
- ### Notes & Features
1160
- - **Advanced Genome Diagram** uses Biopython’s `GenomeDiagram` (requires `pdf2image` if you want it as an image).
1161
- - **Additional Stats**: N50, Shannon entropy, etc.
1162
- - **Auto-scaling** for comparative analysis with adaptive smoothing.
1163
- - **Data Export**: Download CSV of analysis results.
 
 
 
 
 
 
 
 
 
 
 
 
 
1164
  """)
1165
-
1166
  if __name__ == "__main__":
1167
  iface.launch()
 
67
  return sequences
68
 
69
  def sequence_to_kmer_vector(sequence: str, k: int = 4) -> np.ndarray:
 
 
 
70
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
71
  kmer_dict = {km: i for i, km in enumerate(kmers)}
72
  vec = np.zeros(len(kmers), dtype=np.float32)
 
84
  ###############################################################################
85
 
86
  def calculate_shap_values(model, x_tensor):
 
 
 
 
87
  model.eval()
88
  with torch.no_grad():
89
  baseline_output = model(x_tensor)
90
  baseline_probs = torch.softmax(baseline_output, dim=1)
91
+ baseline_prob = baseline_probs[0, 1].item() # Prob of 'human'
92
  shap_values = []
93
  x_zeroed = x_tensor.clone()
94
  for i in range(x_tensor.shape[1]):
 
106
  ###############################################################################
107
 
108
  def compute_positionwise_scores(sequence, shap_values, k=4):
 
 
 
109
  kmers = [''.join(p) for p in product("ACGT", repeat=k)]
110
  kmer_dict = {km: i for i, km in enumerate(kmers)}
111
  seq_len = len(sequence)
 
126
  ###############################################################################
127
 
128
  def find_extreme_subregion(shap_means, window_size=500, mode="max"):
 
 
 
129
  n = len(shap_means)
130
  if n == 0:
131
  return (0, 0, 0.0)
 
152
  ###############################################################################
153
 
154
  def fig_to_image(fig):
 
 
 
155
  buf = io.BytesIO()
156
  fig.savefig(buf, format='png', bbox_inches='tight', dpi=150)
157
  buf.seek(0)
 
160
  return img
161
 
162
  def get_zero_centered_cmap():
 
 
 
163
  colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
164
  return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
165
 
166
  def plot_linear_heatmap(shap_means, title="Per-base SHAP Heatmap", start=None, end=None):
 
 
 
167
  if start is not None and end is not None:
168
  local_shap = shap_means[start:end]
169
  subtitle = f" (positions {start}-{end})"
 
189
  return fig
190
 
191
  def create_importance_bar_plot(shap_values, kmers, top_k=10):
 
 
 
192
  plt.rcParams.update({'font.size': 10})
193
  fig = plt.figure(figsize=(10, 5))
194
  indices = np.argsort(np.abs(shap_values))[-top_k:]
 
204
  return fig
205
 
206
  def plot_shap_histogram(shap_array, title="SHAP Distribution in Region", num_bins=30):
 
 
 
207
  fig, ax = plt.subplots(figsize=(6, 4))
208
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black')
209
  ax.axvline(0, color='red', linestyle='--', label='0.0')
 
215
  return fig
216
 
217
  def compute_gc_content(sequence):
 
 
 
218
  if not sequence:
219
+ return 0
220
  gc_count = sequence.count('G') + sequence.count('C')
221
  return (gc_count / len(sequence)) * 100.0
222
 
 
225
  ###############################################################################
226
 
227
  def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
 
 
 
 
 
228
  if fasta_text.strip():
229
  text = fasta_text.strip()
230
  elif file_obj is not None:
 
236
  else:
237
  return ("Please provide a FASTA sequence.", None, None, None, None, None)
238
 
 
239
  sequences = parse_fasta(text)
240
  if not sequences:
241
  return ("No valid FASTA sequences found.", None, None, None, None, None)
242
  header, seq = sequences[0]
243
 
 
244
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
245
  try:
246
+ # IMPORTANT: adjust how you load your model as needed
247
  state_dict = torch.load('model.pt', map_location=device)
248
  model = VirusClassifier(256).to(device)
249
  model.load_state_dict(state_dict)
 
260
  classification = "Human" if prob_human > 0.5 else "Non-human"
261
  confidence = max(prob_human, prob_nonhuman)
262
 
 
263
  shap_means = compute_positionwise_scores(seq, shap_values, k=4)
264
  max_start, max_end, max_avg = find_extreme_subregion(shap_means, window_size, mode="max")
265
  min_start, min_end, min_avg = find_extreme_subregion(shap_means, window_size, mode="min")
266
 
 
267
  results_text = (
268
  f"Sequence: {header}\n"
269
  f"Length: {len(seq):,} bases\n"
 
277
  f"Start: {min_start}, End: {min_end}, Avg SHAP: {min_avg:.4f}"
278
  )
279
 
 
280
  kmers = [''.join(p) for p in product("ACGT", repeat=4)]
281
  bar_fig = create_importance_bar_plot(shap_values, kmers, top_kmers)
282
  bar_img = fig_to_image(bar_fig)
 
284
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
285
  heatmap_img = fig_to_image(heatmap_fig)
286
 
287
+ # You might want to provide a CSV or other data for the 6th return item
288
+ # Here, we'll simply return None for the file download:
289
  state_dict_out = {"seq": seq, "shap_means": shap_means}
290
 
 
291
  return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
292
 
293
  ###############################################################################
 
295
  ###############################################################################
296
 
297
  def analyze_subregion(state, header, region_start, region_end):
 
 
 
298
  if not state or "seq" not in state or "shap_means" not in state:
299
  return ("No sequence data found. Please run Step 1 first.", None, None, None)
300
  seq = state["seq"]
 
305
  region_end = max(0, min(region_end, len(seq)))
306
  if region_end <= region_start:
307
  return ("Invalid region range. End must be > Start.", None, None, None)
 
308
  region_seq = seq[region_start:region_end]
309
  region_shap = shap_means[region_start:region_end]
 
310
  gc_percent = compute_gc_content(region_seq)
311
  avg_shap = float(np.mean(region_shap))
312
  positive_fraction = np.mean(region_shap > 0)
313
  negative_fraction = np.mean(region_shap < 0)
 
314
  if avg_shap > 0.05:
315
  region_classification = "Likely pushing toward human"
316
  elif avg_shap < -0.05:
317
  region_classification = "Likely pushing toward non-human"
318
  else:
319
  region_classification = "Near neutral (no strong push)"
 
320
  region_info = (
321
  f"Analyzing subregion of {header} from {region_start} to {region_end}\n"
322
  f"Region length: {len(region_seq)} bases\n"
 
326
  f"Fraction with SHAP < 0 (toward non-human): {negative_fraction:.2f}\n"
327
  f"Subregion interpretation: {region_classification}\n"
328
  )
 
329
  heatmap_fig = plot_linear_heatmap(shap_means, title="Subregion SHAP", start=region_start, end=region_end)
330
  heatmap_img = fig_to_image(heatmap_fig)
 
331
  hist_fig = plot_shap_histogram(region_shap, title="SHAP Distribution in Subregion")
332
  hist_img = fig_to_image(hist_fig)
333
+
334
+ # For demonstration, returning None for the file download as well
335
  return (region_info, heatmap_img, hist_img, None)
336
 
337
  ###############################################################################
338
+ # 9. COMPARISON ANALYSIS FUNCTIONS
339
  ###############################################################################
340
 
341
+ def get_zero_centered_cmap():
342
+ """Create a zero-centered blue-white-red colormap"""
343
+ colors = [(0.0, 'blue'), (0.5, 'white'), (1.0, 'red')]
344
+ return mcolors.LinearSegmentedColormap.from_list("blue_white_red", colors)
345
+
346
  def compute_shap_difference(shap1_norm, shap2_norm):
347
+ """Compute the SHAP difference between normalized sequences"""
 
 
348
  return shap2_norm - shap1_norm
349
 
350
  def plot_comparative_heatmap(shap_diff, title="SHAP Difference Heatmap"):
351
  """
352
+ Plot heatmap using relative positions (0-100%)
353
  """
354
  heatmap_data = shap_diff.reshape(1, -1)
355
  extent = max(abs(np.min(shap_diff)), abs(np.max(shap_diff)))
 
378
 
379
  def plot_shap_histogram(shap_array, title="SHAP Distribution", num_bins=30):
380
  """
381
+ Plot histogram of SHAP values with configurable number of bins
382
  """
383
  fig, ax = plt.subplots(figsize=(6, 4))
384
  ax.hist(shap_array, bins=num_bins, color='gray', edgecolor='black', alpha=0.7)
 
392
 
393
  def calculate_adaptive_parameters(len1, len2):
394
  """
395
+ Calculate adaptive parameters based on sequence lengths and their difference.
396
+ Returns: (num_points, smooth_window, resolution_factor)
397
  """
398
  length_diff = abs(len1 - len2)
399
  max_length = max(len1, len2)
400
  min_length = min(len1, len2)
401
  length_ratio = min_length / max_length
402
 
403
+ # Base number of points scales with sequence length
404
  base_points = min(2000, max(500, max_length // 100))
405
 
406
+ # Adjust parameters based on sequence properties
407
  if length_diff < 500:
408
  resolution_factor = 2.0
409
  num_points = min(3000, base_points * 2)
 
421
  num_points = max(500, base_points // 2)
422
  smooth_window = max(100, length_diff // 500)
423
 
424
+ # Adjust window size based on length ratio
425
  smooth_window = int(smooth_window * (1 + (1 - length_ratio)))
426
+
427
  return int(num_points), int(smooth_window), resolution_factor
428
 
429
  def sliding_window_smooth(values, window_size=50):
430
  """
431
+ Apply sliding window smoothing with edge handling
432
  """
433
  if window_size < 3:
434
  return values
435
+
436
+ # Create window with exponential decay at edges
437
  window = np.ones(window_size)
438
  decay = np.exp(-np.linspace(0, 3, window_size // 2))
439
  window[:window_size // 2] = decay
440
  window[-(window_size // 2):] = decay[::-1]
441
  window = window / window.sum()
442
 
443
+ # Apply convolution
444
  smoothed = np.convolve(values, window, mode='valid')
445
+
446
+ # Handle edges
447
  pad_size = len(values) - len(smoothed)
448
  pad_left = pad_size // 2
449
  pad_right = pad_size - pad_left
 
457
 
458
  def normalize_shap_lengths(shap1, shap2):
459
  """
460
+ Normalize and smooth SHAP values with dynamic adaptation
461
  """
462
+ # Calculate adaptive parameters
463
  num_points, smooth_window, _ = calculate_adaptive_parameters(len(shap1), len(shap2))
464
 
465
+ # Apply initial smoothing
466
  shap1_smooth = sliding_window_smooth(shap1, smooth_window)
467
  shap2_smooth = sliding_window_smooth(shap2, smooth_window)
468
 
469
+ # Create relative positions and interpolate
470
  x1 = np.linspace(0, 1, len(shap1_smooth))
471
  x2 = np.linspace(0, 1, len(shap2_smooth))
472
  x_norm = np.linspace(0, 1, num_points)
 
478
 
479
  def analyze_sequence_comparison(file1, file2, fasta1="", fasta2=""):
480
  """
481
+ Compare two sequences with adaptive parameters and visualization
 
482
  """
483
  try:
484
  # Analyze first sequence
 
491
  if isinstance(res2[0], str) and "Error" in res2[0]:
492
  return (f"Error in sequence 2: {res2[0]}", None, None, None)
493
 
494
+ # Extract SHAP values and sequence info
495
  shap1 = res1[3]["shap_means"]
496
  shap2 = res2[3]["shap_means"]
497
 
498
+ # Calculate sequence properties
499
  len1, len2 = len(shap1), len(shap2)
500
  length_diff = abs(len1 - len2)
501
  length_ratio = min(len1, len2) / max(len1, len2)
502
+
503
+ # Normalize and compare sequences
504
  shap1_norm, shap2_norm, smooth_window = normalize_shap_lengths(shap1, shap2)
505
  shap_diff = compute_shap_difference(shap1_norm, shap2_norm)
506
 
507
+ # Calculate adaptive threshold and statistics
508
  base_threshold = 0.05
509
  adaptive_threshold = base_threshold * (1 + (1 - length_ratio))
510
  if length_diff > 50000:
511
  adaptive_threshold *= 1.5
512
 
513
+ # Calculate comparison statistics
514
  avg_diff = np.mean(shap_diff)
515
  std_diff = np.std(shap_diff)
516
  max_diff = np.max(shap_diff)
 
518
  substantial_diffs = np.abs(shap_diff) > adaptive_threshold
519
  frac_different = np.mean(substantial_diffs)
520
 
521
+ # Extract classifications
522
  try:
523
  classification1 = res1[0].split('Classification: ')[1].split('\n')[0].strip()
524
  classification2 = res2[0].split('Classification: ')[1].split('\n')[0].strip()
 
526
  classification1 = "Unknown"
527
  classification2 = "Unknown"
528
 
529
+ # Format output text
530
  comparison_text = (
531
  "Sequence Comparison Results:\n"
532
  f"Sequence 1: {res1[4]}\n"
 
553
  "- White regions: Similar between sequences"
554
  )
555
 
556
+ # Generate visualizations
557
  heatmap_fig = plot_comparative_heatmap(
558
  shap_diff,
559
  title=f"SHAP Difference Heatmap (window: {smooth_window})"
560
  )
561
  heatmap_img = fig_to_image(heatmap_fig)
562
 
563
+ # Create histogram with adaptive bins
564
  num_bins = max(20, min(50, int(np.sqrt(len(shap_diff)))))
565
  hist_fig = plot_shap_histogram(
566
  shap_diff,
 
569
  )
570
  hist_img = fig_to_image(hist_fig)
571
 
572
+ # Return 4 outputs (text, image, image, and a file or None for the last)
573
  return (comparison_text, heatmap_img, hist_img, None)
574
 
575
  except Exception as e:
576
  error_msg = f"Error during sequence comparison: {str(e)}"
577
  return (error_msg, None, None, None)
578
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
579
  ###############################################################################
580
  # 11. GENE FEATURE ANALYSIS
581
  ###############################################################################
582
 
583
+ import io
584
+ from io import BytesIO
585
+ from PIL import Image, ImageDraw, ImageFont
586
+ import numpy as np
587
+ import pandas as pd
588
+ import tempfile
589
+ import os
590
+ from typing import List, Dict, Tuple, Optional, Any
591
+ import matplotlib.pyplot as plt
592
+ from matplotlib.colors import LinearSegmentedColormap
593
+ import seaborn as sns
594
+
595
  def parse_gene_features(text: str) -> List[Dict[str, Any]]:
596
+ """Parse gene features from text file in FASTA-like format"""
597
  genes = []
598
  current_header = None
599
  current_sequence = []
 
602
  line = line.strip()
603
  if not line:
604
  continue
605
+
606
  if line.startswith('>'):
607
  if current_header:
608
  genes.append({
 
614
  current_sequence = []
615
  else:
616
  current_sequence.append(line.upper())
617
+
618
  if current_header:
619
  genes.append({
620
  'header': current_header,
621
  'sequence': ''.join(current_sequence),
622
  'metadata': parse_gene_metadata(current_header)
623
  })
624
+
625
  return genes
626
 
627
  def parse_gene_metadata(header: str) -> Dict[str, str]:
628
+ """Extract metadata from gene header"""
629
  metadata = {}
630
  parts = header.split()
631
+
632
  for part in parts:
633
  if '[' in part and ']' in part:
634
  key_value = part[1:-1].split('=', 1)
635
  if len(key_value) == 2:
636
  metadata[key_value[0]] = key_value[1]
637
+
638
  return metadata
639
 
640
  def parse_location(location_str: str) -> Tuple[Optional[int], Optional[int]]:
641
+ """Parse gene location string, handling both forward and complement strands"""
642
  try:
643
+ # Remove 'complement(' and ')' if present
644
  clean_loc = location_str.replace('complement(', '').replace(')', '')
645
+
646
+ # Split on '..' and convert to integers
647
  if '..' in clean_loc:
648
  start, end = map(int, clean_loc.split('..'))
649
  return start, end
 
654
  return None, None
655
 
656
  def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
657
+ """Compute statistical measures for gene SHAP values"""
658
  return {
659
+ 'avg_shap': float(np.mean(gene_shap)),
660
+ 'median_shap': float(np.median(gene_shap)),
661
+ 'std_shap': float(np.std(gene_shap)),
662
+ 'max_shap': float(np.max(gene_shap)),
663
+ 'min_shap': float(np.min(gene_shap)),
664
+ 'pos_fraction': float(np.mean(gene_shap > 0))
665
  }
666
 
667
  def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
668
  """
669
+ Create a simple genome diagram using PIL, forcing a minimum color intensity
670
+ so that small SHAP values don't appear white.
671
  """
672
+ from PIL import Image, ImageDraw, ImageFont
673
+
674
+ # Validate inputs
675
  if not gene_results or genome_length <= 0:
676
  img = Image.new('RGB', (800, 100), color='white')
677
  draw = ImageDraw.Draw(img)
678
  draw.text((10, 40), "Error: Invalid input data", fill='black')
679
  return img
680
+
681
+ # Ensure all gene coordinates are valid integers
682
  for gene in gene_results:
683
  gene['start'] = max(0, int(gene['start']))
684
  gene['end'] = min(genome_length, int(gene['end']))
685
  if gene['start'] >= gene['end']:
686
+ print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}: {gene['start']}-{gene['end']}")
687
 
688
+ # Image dimensions
689
  width = 1500
690
  height = 600
691
  margin = 50
692
  track_height = 40
693
 
694
+ # Create image with white background
695
  img = Image.new('RGB', (width, height), 'white')
696
  draw = ImageDraw.Draw(img)
697
 
698
+ # Try to load font, fall back to default if unavailable
699
  try:
700
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
701
  title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
 
703
  font = ImageFont.load_default()
704
  title_font = ImageFont.load_default()
705
 
706
+ # Draw title
707
+ draw.text((margin, margin // 2), "Genome SHAP Analysis", fill='black', font=title_font or font)
708
 
709
+ # Draw genome line
710
  line_y = height // 2
711
  draw.line([(int(margin), int(line_y)), (int(width - margin), int(line_y))], fill='black', width=2)
712
 
713
+ # Calculate scale factor
714
  scale = float(width - 2 * margin) / float(genome_length)
715
 
716
+ # Determine a reasonable step for scale markers
717
  num_ticks = 10
718
+ if genome_length < num_ticks:
719
+ step = 1
720
+ else:
721
+ step = genome_length // num_ticks
722
+
723
+ # Draw scale markers
724
  for i in range(0, genome_length + 1, step):
725
  x_coord = margin + i * scale
726
  draw.line([
 
729
  ], fill='black', width=1)
730
  draw.text((int(x_coord - 20), int(line_y + 10)), f"{i:,}", fill='black', font=font)
731
 
732
+ # Sort genes by absolute SHAP value for drawing
733
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']))
734
+
735
+ # Draw genes
736
  for idx, gene in enumerate(sorted_genes):
737
+ # Calculate position and ensure integers
738
  start_x = margin + int(gene['start'] * scale)
739
  end_x = margin + int(gene['end'] * scale)
740
+
741
+ # Calculate color based on SHAP value
742
  avg_shap = gene['avg_shap']
743
+
744
+ # Convert shap -> color intensity (0 to 255)
745
+ # Then clamp to a minimum intensity so it never ends up plain white
746
  intensity = int(abs(avg_shap) * 500)
747
+ intensity = max(50, min(255, intensity)) # clamp between 50 and 255
748
 
749
  if avg_shap > 0:
750
+ # Red-ish for positive
751
+ color = (255, 255 - intensity, 255 - intensity)
752
  else:
753
+ # Blue-ish for negative or zero
754
+ color = (255 - intensity, 255 - intensity, 255)
755
 
756
+ # Draw gene rectangle
757
  draw.rectangle([
758
  (int(start_x), int(line_y - track_height // 2)),
759
  (int(end_x), int(line_y + track_height // 2))
760
  ], fill=color, outline='black')
761
 
762
+ # Prepare gene name label
763
  label = str(gene.get('gene_name','?'))
764
+
765
+ # Fallback for label size
766
  label_mask = font.getmask(label)
767
  label_width, label_height = label_mask.size
768
 
769
+ # Alternate label positions
770
  if idx % 2 == 0:
771
  text_y = line_y - track_height - 15
772
  else:
773
  text_y = line_y + track_height + 5
774
 
775
+ # Decide whether to rotate text based on space
776
  gene_width = end_x - start_x
777
  if gene_width > label_width:
778
  text_x = start_x + (gene_width - label_width) // 2
 
784
  rotated_img = txt_img.rotate(90, expand=True)
785
  img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)
786
 
787
+ # Draw legend
788
+ legend_x = margin
789
+ legend_y = height - margin
790
+ draw.text((int(legend_x), int(legend_y - 60)), "SHAP Values:", fill='black', font=font)
791
+
792
+ # Draw legend boxes
793
+ box_width = 20
794
+ box_height = 20
795
+ spacing = 15
796
+
797
+ # Strong human-like
798
+ draw.rectangle([
799
+ (int(legend_x), int(legend_y - 45)),
800
+ (int(legend_x + box_width), int(legend_y - 45 + box_height))
801
+ ], fill=(255, 0, 0), outline='black')
802
+ draw.text((int(legend_x + box_width + spacing), int(legend_y - 45)),
803
+ "Strong human-like signal", fill='black', font=font)
804
+
805
+ # Weak human-like
806
+ draw.rectangle([
807
+ (int(legend_x), int(legend_y - 20)),
808
+ (int(legend_x + box_width), int(legend_y - 20 + box_height))
809
+ ], fill=(255, 200, 200), outline='black')
810
+ draw.text((int(legend_x + box_width + spacing), int(legend_y - 20)),
811
+ "Weak human-like signal", fill='black', font=font)
812
+
813
+ # Weak non-human-like
814
+ draw.rectangle([
815
+ (int(legend_x + 250), int(legend_y - 45)),
816
+ (int(legend_x + 250 + box_width), int(legend_y - 45 + box_height))
817
+ ], fill=(200, 200, 255), outline='black')
818
+ draw.text((int(legend_x + 250 + box_width + spacing), int(legend_y - 45)),
819
+ "Weak non-human-like signal", fill='black', font=font)
820
+
821
+ # Strong non-human-like
822
+ draw.rectangle([
823
+ (int(legend_x + 250), int(legend_y - 20)),
824
+ (int(legend_x + 250 + box_width), int(legend_y - 20 + box_height))
825
+ ], fill=(0, 0, 255), outline='black')
826
+ draw.text((int(legend_x + 250 + box_width + spacing), int(legend_y - 20)),
827
+ "Strong non-human-like signal", fill='black', font=font)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
828
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
829
  return img
830
 
831
  def analyze_gene_features(sequence_file: str,
832
  features_file: str,
833
  fasta_text: str = "",
834
+ features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
835
+ """Analyze SHAP values for each gene feature"""
836
+ # First analyze whole sequence
 
 
 
 
 
837
  sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
838
  if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
839
  return f"Error in sequence analysis: {sequence_results[0]}", None, None
840
 
841
+ # Get SHAP values
842
  shap_means = sequence_results[3]["shap_means"]
843
+
844
+ # Parse gene features
 
845
  try:
846
  if features_text.strip():
847
  genes = parse_gene_features(features_text)
 
850
  genes = parse_gene_features(f.read())
851
  except Exception as e:
852
  return f"Error reading features file: {str(e)}", None, None
853
+
854
+ # Analyze each gene
855
  gene_results = []
856
  for gene in genes:
857
+ try:
858
+ location = gene['metadata'].get('location', '')
859
+ if not location:
860
+ continue
861
+
862
+ start, end = parse_location(location)
863
+ if start is None or end is None:
864
+ continue
865
+
866
+ # Get SHAP values for this region
867
+ gene_shap = shap_means[start:end]
868
+ stats = compute_gene_statistics(gene_shap)
869
+
870
+ gene_results.append({
871
+ 'gene_name': gene['metadata'].get('gene', 'Unknown'),
872
+ 'location': location,
873
+ 'start': start,
874
+ 'end': end,
875
+ 'locus_tag': gene['metadata'].get('locus_tag', ''),
876
+ 'avg_shap': stats['avg_shap'],
877
+ 'median_shap': stats['median_shap'],
878
+ 'std_shap': stats['std_shap'],
879
+ 'max_shap': stats['max_shap'],
880
+ 'min_shap': stats['min_shap'],
881
+ 'pos_fraction': stats['pos_fraction'],
882
+ 'classification': 'Human' if stats['avg_shap'] > 0 else 'Non-human',
883
+ 'confidence': abs(stats['avg_shap'])
884
+ })
885
+
886
+ except Exception as e:
887
+ print(f"Error processing gene {gene['metadata'].get('gene', 'Unknown')}: {str(e)}")
888
  continue
889
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
890
  if not gene_results:
891
  return "No valid genes could be processed", None, None
892
+
893
+ # Sort genes by absolute SHAP value
894
  sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)
895
+
896
+ # Create results text
897
  results_text = "Gene Analysis Results:\n\n"
898
  results_text += f"Total genes analyzed: {len(gene_results)}\n"
899
+ results_text += f"Human-like genes: {sum(1 for g in gene_results if g['classification'] == 'Human')}\n"
900
+ results_text += f"Non-human-like genes: {sum(1 for g in gene_results if g['classification'] == 'Non-human')}\n\n"
 
901
 
902
+ results_text += "Top 10 most distinctive genes:\n"
903
  for gene in sorted_genes[:10]:
904
  results_text += (
905
  f"Gene: {gene['gene_name']}\n"
906
  f"Location: {gene['location']}\n"
907
  f"Classification: {gene['classification']} "
908
  f"(confidence: {gene['confidence']:.4f})\n"
909
+ f"Average SHAP: {gene['avg_shap']:.4f}\n\n"
 
910
  )
911
+
912
+ # Create CSV content
913
+ csv_content = "gene_name,location,avg_shap,median_shap,std_shap,max_shap,min_shap,"
914
+ csv_content += "pos_fraction,classification,confidence,locus_tag\n"
915
+
916
+ for gene in gene_results:
917
  csv_content += (
918
+ f"{gene['gene_name']},{gene['location']},{gene['avg_shap']:.4f},"
919
+ f"{gene['median_shap']:.4f},{gene['std_shap']:.4f},{gene['max_shap']:.4f},"
920
+ f"{gene['min_shap']:.4f},{gene['pos_fraction']:.4f},{gene['classification']},"
921
+ f"{gene['confidence']:.4f},{gene['locus_tag']}\n"
922
  )
923
+
924
+ # Save CSV to temp file
925
  try:
926
  temp_dir = tempfile.gettempdir()
927
  temp_path = os.path.join(temp_dir, f"gene_analysis_{os.urandom(4).hex()}.csv")
928
+
929
  with open(temp_path, 'w') as f:
930
  f.write(csv_content)
931
  except Exception as e:
932
  print(f"Error saving CSV: {str(e)}")
933
  temp_path = None
934
+
935
+ # Create visualization
936
  try:
937
+ diagram_img = create_simple_genome_diagram(gene_results, len(shap_means))
 
 
 
938
  except Exception as e:
939
  print(f"Error creating visualization: {str(e)}")
940
+ # Create error image
941
  diagram_img = Image.new('RGB', (800, 100), color='white')
942
  draw = ImageDraw.Draw(diagram_img)
943
  draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')
944
+
945
  return results_text, temp_path, diagram_img
946
 
947
  ###############################################################################
 
949
  ###############################################################################
950
 
951
  def prepare_csv_download(data, filename="analysis_results.csv"):
952
+ """Prepare CSV data for download"""
 
 
953
  if isinstance(data, str):
954
  return data.encode(), filename
955
  elif isinstance(data, (list, dict)):
956
  import csv
957
  from io import StringIO
958
+
959
  output = StringIO()
960
  writer = csv.DictWriter(output, fieldnames=data[0].keys())
961
  writer.writeheader()
 
979
 
980
  with gr.Blocks(css=css) as iface:
981
  gr.Markdown("""
982
+ # Virus Host Classifier
983
+ **Step 1**: Predict overall viral sequence origin (human vs non-human) and identify extreme regions.
984
+ **Step 2**: Explore subregions to see local SHAP signals, distribution, GC content, etc.
985
+ **Step 3**: Analyze gene features and their contributions.
986
+ **Step 4**: Compare sequences and analyze differences.
987
+
988
+ **Color Scale**: Negative SHAP = Blue, Zero = White, Positive SHAP = Red.
989
  """)
990
 
991
  with gr.Tab("1) Full-Sequence Analysis"):
992
  with gr.Row():
993
  with gr.Column(scale=1):
994
  file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
995
+ text_input = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
996
  top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
997
+ win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window size for 'most pushing' subregions")
998
  analyze_btn = gr.Button("Analyze Sequence", variant="primary")
999
  with gr.Column(scale=2):
1000
  results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
 
1013
  with gr.Tab("2) Subregion Exploration"):
1014
  gr.Markdown("""
1015
  **Subregion Analysis**
1016
+ Select start/end positions to view local SHAP signals, distribution, GC content, etc.
1017
+ The heatmap uses the same Blue-White-Red scale.
1018
  """)
1019
  with gr.Row():
1020
  region_start = gr.Number(label="Region Start", value=0)
 
1024
  with gr.Row():
1025
  subregion_img = gr.Image(label="Subregion SHAP Heatmap (B-W-R)")
1026
  subregion_hist_img = gr.Image(label="SHAP Distribution (Histogram)")
1027
+ download_subregion = gr.File(label="Download Subregion Analysis", visible=False, elem_classes="download-button")
1028
 
1029
  region_btn.click(
1030
  analyze_subregion,
 
1035
  with gr.Tab("3) Gene Features Analysis"):
1036
  gr.Markdown("""
1037
  **Analyze Gene Features**
1038
+ Upload a FASTA file and corresponding gene features file to analyze SHAP values per gene.
1039
+ Gene features should be in the format:
1040
+ ```
1041
+ >gene_name [gene=X] [locus_tag=Y] [location=start..end] or [location=complement(start..end)]
1042
+ SEQUENCE
1043
+ ```
1044
+ The genome viewer will show genes color-coded by their contribution:
1045
+ - Red: Genes pushing toward human origin
1046
+ - Blue: Genes pushing toward non-human origin
1047
+ - Color intensity indicates strength of signal
1048
  """)
1049
  with gr.Row():
1050
  with gr.Column(scale=1):
1051
+ gene_fasta_file = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1052
+ gene_fasta_text = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
1053
  with gr.Column(scale=1):
1054
+ features_file = gr.File(label="Upload gene features file", file_types=[".txt"], type="filepath")
1055
+ features_text = gr.Textbox(label="Or paste gene features", placeholder=">gene_1 [gene=U12]...\nACGT...", lines=5)
1056
+
1057
  analyze_genes_btn = gr.Button("Analyze Gene Features", variant="primary")
1058
  gene_results = gr.Textbox(label="Gene Analysis Results", lines=12, interactive=False)
1059
+ gene_diagram = gr.Image(label="Genome Diagram with Gene Features")
1060
  download_gene_results = gr.File(label="Download Gene Analysis (CSV)", visible=True)
1061
 
1062
  analyze_genes_btn.click(
1063
  analyze_gene_features,
1064
+ inputs=[gene_fasta_file, features_file, gene_fasta_text, features_text],
1065
  outputs=[gene_results, download_gene_results, gene_diagram]
1066
  )
1067
 
1068
  with gr.Tab("4) Comparative Analysis"):
1069
  gr.Markdown("""
1070
  **Compare Two Sequences**
1071
+ Upload or paste two FASTA sequences to compare their SHAP patterns.
1072
+ The sequences will be normalized to the same length for comparison.
1073
+
1074
+ **Color Scale**:
1075
+ - Red: Sequence 2 more human-like
1076
+ - Blue: Sequence 1 more human-like
1077
+ - White: No substantial difference
1078
  """)
1079
  with gr.Row():
1080
  with gr.Column(scale=1):
1081
+ file_input1 = gr.File(label="Upload first FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1082
+ text_input1 = gr.Textbox(label="Or paste first FASTA sequence", placeholder=">sequence1\nACGTACGT...", lines=5)
1083
  with gr.Column(scale=1):
1084
+ file_input2 = gr.File(label="Upload second FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1085
+ text_input2 = gr.Textbox(label="Or paste second FASTA sequence", placeholder=">sequence2\nACGTACGT...", lines=5)
1086
  compare_btn = gr.Button("Compare Sequences", variant="primary")
1087
  comparison_text = gr.Textbox(label="Comparison Results", lines=12, interactive=False)
1088
  with gr.Row():
1089
  diff_heatmap = gr.Image(label="SHAP Difference Heatmap")
1090
  diff_hist = gr.Image(label="Distribution of SHAP Differences")
1091
+ download_comparison = gr.File(label="Download Comparison Results", visible=False, elem_classes="download-button")
1092
 
1093
  compare_btn.click(
1094
  analyze_sequence_comparison,
 
1097
  )
1098
 
1099
  gr.Markdown("""
1100
+ ### Interface Features
1101
+ - **Overall Classification** (human vs non-human) using k-mer frequencies
1102
+ - **SHAP Analysis** shows which k-mers push classification toward or away from human
1103
+ - **White-Centered SHAP Gradient**:
1104
+ - Negative (blue), 0 (white), Positive (red)
1105
+ - Symmetrical color range around 0
1106
+ - **Identify Subregions** with strongest push for human or non-human
1107
+ - **Gene Feature Analysis**:
1108
+ - Analyze individual genes' contributions
1109
+ - Interactive genome viewer
1110
+ - Gene-level statistics and classification
1111
+ - **Sequence Comparison**:
1112
+ - Compare two sequences to identify regions of difference
1113
+ - Normalized comparison to handle different lengths
1114
+ - Statistical summary of differences
1115
+ - **Data Export**:
1116
+ - Download results as CSV files
1117
+ - Save analysis outputs for further processing
1118
  """)
1119
+
1120
  if __name__ == "__main__":
1121
  iface.launch()