hiyata commited on
Commit
9f540d3
·
verified ·
1 Parent(s): 3259cc8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -16
app.py CHANGED
@@ -644,7 +644,10 @@ def compute_gene_statistics(gene_shap: np.ndarray) -> Dict[str, float]:
644
  }
645
 
646
  def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
647
- """Create a simple genome diagram using PIL"""
 
 
 
648
  from PIL import Image, ImageDraw, ImageFont
649
 
650
  # Validate inputs
@@ -659,7 +662,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
659
  gene['start'] = max(0, int(gene['start']))
660
  gene['end'] = min(genome_length, int(gene['end']))
661
  if gene['start'] >= gene['end']:
662
- print(f"Warning: Invalid coordinates for gene {gene['gene_name']}: {gene['start']}-{gene['end']}")
663
 
664
  # Image dimensions
665
  width = 1500
@@ -689,7 +692,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
689
  # Calculate scale factor
690
  scale = float(width - 2 * margin) / float(genome_length)
691
 
692
- # Determine a reasonable step for scale markers (avoid zero step if genome_length < 10)
693
  num_ticks = 10
694
  if genome_length < num_ticks:
695
  step = 1
@@ -712,43 +715,50 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
712
  for idx, gene in enumerate(sorted_genes):
713
  # Calculate position and ensure integers
714
  start_x = margin + int(gene['start'] * scale)
715
- end_x = margin + int(gene['end'] * scale)
716
 
717
  # Calculate color based on SHAP value
718
- if gene['avg_shap'] > 0:
719
- intensity = min(255, int(abs(gene['avg_shap'] * 500)))
720
- color = (255, max(0, 255 - intensity), max(0, 255 - intensity)) # Red-ish
 
 
 
 
 
 
 
721
  else:
722
- intensity = min(255, int(abs(gene['avg_shap'] * 500)))
723
- color = (max(0, 255 - intensity), max(0, 255 - intensity), 255) # Blue-ish
724
 
725
- # Draw gene box
726
  draw.rectangle([
727
  (int(start_x), int(line_y - track_height // 2)),
728
  (int(end_x), int(line_y + track_height // 2))
729
  ], fill=color, outline='black')
730
 
731
  # Prepare gene name label
732
- label = f"{gene['gene_name']}"
733
 
734
- # Fallback approach: use getmask(...) to get (width, height)
 
 
735
  label_mask = font.getmask(label)
736
  label_width, label_height = label_mask.size
737
 
738
- # Alternate label positions above/below
739
  if idx % 2 == 0:
740
  text_y = line_y - track_height - 15
741
  else:
742
  text_y = line_y + track_height + 5
743
 
744
- # Decide to rotate text or not based on available box width
745
  gene_width = end_x - start_x
746
  if gene_width > label_width:
747
- # Draw horizontally
748
  text_x = start_x + (gene_width - label_width) // 2
749
  draw.text((int(text_x), int(text_y)), label, fill='black', font=font)
750
  elif gene_width > 20:
751
- # Create rotated text
752
  txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
753
  txt_draw = ImageDraw.Draw(txt_img)
754
  txt_draw.text((0, 0), label, font=font, fill='black')
@@ -800,6 +810,7 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
800
  return img
801
 
802
 
 
803
  def analyze_gene_features(sequence_file: str,
804
  features_file: str,
805
  fasta_text: str = "",
 
644
  }
645
 
646
  def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_length: int) -> Image.Image:
647
+ """
648
+ Create a simple genome diagram using PIL, forcing a minimum color intensity
649
+ so that small SHAP values don't appear white.
650
+ """
651
  from PIL import Image, ImageDraw, ImageFont
652
 
653
  # Validate inputs
 
662
  gene['start'] = max(0, int(gene['start']))
663
  gene['end'] = min(genome_length, int(gene['end']))
664
  if gene['start'] >= gene['end']:
665
+ print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}: {gene['start']}-{gene['end']}")
666
 
667
  # Image dimensions
668
  width = 1500
 
692
  # Calculate scale factor
693
  scale = float(width - 2 * margin) / float(genome_length)
694
 
695
+ # Determine a reasonable step for scale markers
696
  num_ticks = 10
697
  if genome_length < num_ticks:
698
  step = 1
 
715
  for idx, gene in enumerate(sorted_genes):
716
  # Calculate position and ensure integers
717
  start_x = margin + int(gene['start'] * scale)
718
+ end_x = margin + int(gene['end'] * scale)
719
 
720
  # Calculate color based on SHAP value
721
+ avg_shap = gene['avg_shap']
722
+
723
+ # Convert shap -> color intensity (0 to 255)
724
+ # Then clamp to a minimum intensity so it never ends up plain white
725
+ intensity = int(abs(avg_shap) * 500)
726
+ intensity = max(50, min(255, intensity)) # clamp between 50 and 255
727
+
728
+ if avg_shap > 0:
729
+ # Red-ish for positive
730
+ color = (255, 255 - intensity, 255 - intensity)
731
  else:
732
+ # Blue-ish for negative or zero
733
+ color = (255 - intensity, 255 - intensity, 255)
734
 
735
+ # Draw gene rectangle
736
  draw.rectangle([
737
  (int(start_x), int(line_y - track_height // 2)),
738
  (int(end_x), int(line_y + track_height // 2))
739
  ], fill=color, outline='black')
740
 
741
  # Prepare gene name label
742
+ label = str(gene.get('gene_name','?'))
743
 
744
+ # If getsize() or textsize() is missing, use getmask(...).size as fallback
745
+ # But if your Pillow version supports font.getsize, you can do:
746
+ # label_width, label_height = font.getsize(label)
747
  label_mask = font.getmask(label)
748
  label_width, label_height = label_mask.size
749
 
750
+ # Alternate label positions above/below line
751
  if idx % 2 == 0:
752
  text_y = line_y - track_height - 15
753
  else:
754
  text_y = line_y + track_height + 5
755
 
756
+ # Decide whether to rotate text based on space
757
  gene_width = end_x - start_x
758
  if gene_width > label_width:
 
759
  text_x = start_x + (gene_width - label_width) // 2
760
  draw.text((int(text_x), int(text_y)), label, fill='black', font=font)
761
  elif gene_width > 20:
 
762
  txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
763
  txt_draw = ImageDraw.Draw(txt_img)
764
  txt_draw.text((0, 0), label, font=font, fill='black')
 
810
  return img
811
 
812
 
813
+
814
  def analyze_gene_features(sequence_file: str,
815
  features_file: str,
816
  fasta_text: str = "",