hiyata committed on
Commit
7d672a0
·
verified ·
1 Parent(s): 9f540d3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +144 -117
app.py CHANGED
@@ -811,125 +811,152 @@ def create_simple_genome_diagram(gene_results: List[Dict[str, Any]], genome_leng
811
 
812
 
813
 
814
def analyze_gene_features(sequence_file: str,
                          features_file: str,
                          fasta_text: str = "",
                          features_text: str = "") -> Tuple[str, Optional[str], Optional[Image.Image]]:
    """
    Analyze SHAP values for each gene feature.

    NOTE: This function assumes there's an `analyze_sequence(...)` function
    defined elsewhere that returns the needed SHAP information.

    Args:
        sequence_file: Path forwarded to `analyze_sequence`.
        features_file: Path to the gene-features file; only read when
            `features_text` is blank.
        fasta_text: Raw FASTA text forwarded to `analyze_sequence`.
        features_text: Raw gene-features text; takes precedence over
            `features_file` when it contains non-whitespace characters.

    Returns:
        A `(results_text, csv_path, diagram_image)` tuple. On failure the
        first element is an error message and the other two are `None`.
        `csv_path` is `None` when the CSV could not be written; the diagram
        falls back to a small image carrying the error message when drawing
        fails.
    """
    # Local import so this block does not depend on a file-level `import csv`.
    import csv

    # First analyze the whole sequence.
    sequence_results = analyze_sequence(sequence_file, top_kmers=10, fasta_text=fasta_text)
    if isinstance(sequence_results[0], str) and "Error" in sequence_results[0]:
        return f"Error in sequence analysis: {sequence_results[0]}", None, None

    # Per-position SHAP values (assumed indexable by genome position -- it is
    # sliced with gene start/end below).
    shap_means = sequence_results[3]["shap_means"]

    # Parse gene features, preferring pasted text over the uploaded file.
    try:
        if features_text.strip():
            genes = parse_gene_features(features_text)
        else:
            with open(features_file, 'r') as f:
                genes = parse_gene_features(f.read())
    except Exception as e:
        return f"Error reading features file: {str(e)}", None, None

    # Analyze each gene; a failure on one gene is logged and skipped so the
    # remaining genes are still reported.
    gene_results = []
    for gene in genes:
        try:
            location = gene['metadata'].get('location', '')
            if not location:
                continue

            start, end = parse_location(location)
            if start is None or end is None:
                continue

            # SHAP values covering this gene's region.
            gene_shap = shap_means[start:end]
            stats = compute_gene_statistics(gene_shap)

            gene_results.append({
                'gene_name': gene['metadata'].get('gene', 'Unknown'),
                'location': location,
                'start': start,
                'end': end,
                'locus_tag': gene['metadata'].get('locus_tag', ''),
                'avg_shap': stats['avg_shap'],
                'median_shap': stats['median_shap'],
                'std_shap': stats['std_shap'],
                'max_shap': stats['max_shap'],
                'min_shap': stats['min_shap'],
                'pos_fraction': stats['pos_fraction'],
                # Positive mean SHAP is labelled 'Human', everything else 'Non-human'.
                'classification': 'Human' if stats['avg_shap'] > 0 else 'Non-human',
                'confidence': abs(stats['avg_shap'])
            })

        except Exception as e:
            print(f"Error processing gene {gene['metadata'].get('gene', 'Unknown')}: {str(e)}")
            continue

    if not gene_results:
        return "No valid genes could be processed", None, None

    # Sort genes by absolute SHAP value, most distinctive first.
    sorted_genes = sorted(gene_results, key=lambda x: abs(x['avg_shap']), reverse=True)

    human_count = sum(1 for g in gene_results if g['classification'] == 'Human')
    non_human_count = sum(1 for g in gene_results if g['classification'] == 'Non-human')

    # Build the results text in one join instead of repeated `+=`.
    text_parts = [
        "Gene Analysis Results:\n\n",
        f"Total genes analyzed: {len(gene_results)}\n",
        f"Human-like genes: {human_count}\n",
        f"Non-human-like genes: {non_human_count}\n\n",
        "Top 10 most distinctive genes:\n",
    ]
    text_parts.extend(
        f"Gene: {gene['gene_name']}\n"
        f"Location: {gene['location']}\n"
        f"Classification: {gene['classification']} "
        f"(confidence: {gene['confidence']:.4f})\n"
        f"Average SHAP: {gene['avg_shap']:.4f}\n\n"
        for gene in sorted_genes[:10]
    )
    results_text = "".join(text_parts)

    # Save the CSV to a temp file. Rows go through csv.writer so that commas
    # or quotes inside a field (e.g. a location such as "join(1..10,20..30)")
    # are quoted instead of silently shifting the columns.
    csv_header = [
        "gene_name", "location", "avg_shap", "median_shap", "std_shap",
        "max_shap", "min_shap", "pos_fraction", "classification",
        "confidence", "locus_tag",
    ]
    try:
        temp_dir = tempfile.gettempdir()
        temp_path = os.path.join(temp_dir, f"gene_analysis_{os.urandom(4).hex()}.csv")

        # newline='' lets the csv module control line endings itself.
        with open(temp_path, 'w', newline='') as f:
            writer = csv.writer(f, lineterminator="\n")
            writer.writerow(csv_header)
            for gene in gene_results:
                writer.writerow([
                    gene['gene_name'],
                    gene['location'],
                    f"{gene['avg_shap']:.4f}",
                    f"{gene['median_shap']:.4f}",
                    f"{gene['std_shap']:.4f}",
                    f"{gene['max_shap']:.4f}",
                    f"{gene['min_shap']:.4f}",
                    f"{gene['pos_fraction']:.4f}",
                    gene['classification'],
                    f"{gene['confidence']:.4f}",
                    gene['locus_tag'],
                ])
    except Exception as e:
        # Best effort: the text and diagram are still returned without a CSV.
        print(f"Error saving CSV: {str(e)}")
        temp_path = None

    # Create visualization.
    try:
        diagram_img = create_simple_genome_diagram(gene_results, len(shap_means))
    except Exception as e:
        print(f"Error creating visualization: {str(e)}")
        # Fall back to a small image that carries the error message.
        diagram_img = Image.new('RGB', (800, 100), color='white')
        draw = ImageDraw.Draw(diagram_img)
        draw.text((10, 40), f"Error creating visualization: {str(e)}", fill='black')

    return results_text, temp_path, diagram_img
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
933
 
934
 
935
  ###############################################################################
 
811
 
812
 
813
 
814
def create_simple_genome_diagram(gene_results, genome_length):
    """
    Draw a linear genome diagram with one coloured box per gene.

    Args:
        gene_results: Sequence of dicts; each is read for 'start', 'end',
            'avg_shap' and (optionally) 'gene_name'. The dicts are not
            modified.
        genome_length: Genome length in the same units as 'start'/'end';
            used for clamping, the horizontal scale and the tick marks.

    Returns:
        A 1500x600 RGBA PIL image. When `gene_results` is empty or
        `genome_length` is not positive, an 800x100 image carrying an error
        message is returned instead.
    """
    from PIL import Image, ImageDraw, ImageFont

    # Validate inputs
    if not gene_results or genome_length <= 0:
        img = Image.new('RGBA', (800, 100), color=(255, 255, 255, 255))
        draw = ImageDraw.Draw(img, 'RGBA')
        draw.text((10, 40), "Error: Invalid input data", fill='black')
        return img

    # Clamp gene coordinates into the genome. The clamped values are kept
    # locally so the caller's dicts are left untouched, and genes whose
    # range is empty or inverted are skipped: drawing them would produce a
    # degenerate box (an inverted box can make ImageDraw.rectangle raise).
    drawable = []
    for gene in gene_results:
        start = max(0, int(gene['start']))
        end = min(genome_length, int(gene['end']))
        if start >= end:
            print(f"Warning: Invalid coordinates for gene {gene.get('gene_name','?')}: "
                  f"{start}-{end}")
            continue
        drawable.append((start, end, gene))

    # Dimensions
    width, height = 1500, 600
    margin = 50
    track_height = 40

    # Create RGBA image
    img = Image.new('RGBA', (width, height), (255, 255, 255, 255))
    draw = ImageDraw.Draw(img, 'RGBA')

    # Fonts: fall back to Pillow's built-in font when the DejaVu files are
    # missing or FreeType support is unavailable.
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 12)
        title_font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", 16)
    except (OSError, ImportError):
        font = ImageFont.load_default()
        title_font = ImageFont.load_default()

    # Draw title text
    draw.text((margin, margin // 2), "Genome SHAP Analysis", fill='black', font=title_font)

    # Draw the genome line and ticks first so the gene boxes are painted on
    # top of them.
    # NOTE(review): on an RGBA canvas ImageDraw appears to write the fill's
    # alpha into the pixels rather than blend with what is underneath, so the
    # boxes are translucent in the saved image but the axis is not visible
    # through them -- confirm this is the intended look.
    line_y = height // 2
    draw.line([(margin, line_y), (width - margin, line_y)], fill='black', width=2)

    # Scale factor: pixels per genome position.
    scale = (width - 2 * margin) / float(genome_length)

    # Ticks
    num_ticks = 10
    step = 1 if genome_length < num_ticks else (genome_length // num_ticks)
    for i in range(0, genome_length + 1, step):
        x_coord = margin + i * scale
        draw.line([(int(x_coord), line_y - 5), (int(x_coord), line_y + 5)],
                  fill='black', width=1)
        draw.text((int(x_coord - 20), line_y + 10), f"{i:,}", fill='black', font=font)

    # Sort genes by absolute SHAP so weaker genes are drawn first and
    # stronger genes end up on top where boxes overlap.
    drawable.sort(key=lambda item: abs(item[2]['avg_shap']))

    # Draw gene boxes with partial alpha
    for idx, (start, end, gene) in enumerate(drawable):
        start_x = margin + int(start * scale)
        end_x = margin + int(end * scale)

        # Colour intensity grows with |avg_shap|, clamped to [50, 255] so
        # weak genes do not fade to white.
        avg_shap = gene['avg_shap']
        intensity = min(255, int(abs(avg_shap) * 500))
        intensity = max(50, intensity)

        if avg_shap > 0:
            # Red-ish
            color = (255, 255 - intensity, 255 - intensity, 180)
        else:
            # Blue-ish
            color = (255 - intensity, 255 - intensity, 255, 180)

        draw.rectangle([
            (start_x, line_y - track_height // 2),
            (end_x, line_y + track_height // 2)
        ], fill=color, outline=(0, 0, 0, 255))

        # Label
        label = gene.get('gene_name', '?')
        label_mask = font.getmask(label)
        label_width, label_height = label_mask.size

        # Alternate labels above and below the track to reduce collisions.
        if idx % 2 == 0:
            text_y = line_y - track_height - 15
        else:
            text_y = line_y + track_height + 5

        # If there's room, draw horizontally; else rotate. Boxes of 20 px or
        # narrower get no label at all.
        gene_width = end_x - start_x
        if gene_width > label_width:
            text_x = start_x + (gene_width - label_width) // 2
            draw.text((text_x, text_y), label, fill='black', font=font)
        elif gene_width > 20:
            txt_img = Image.new('RGBA', (label_width, label_height), (255, 255, 255, 0))
            txt_draw = ImageDraw.Draw(txt_img)
            txt_draw.text((0, 0), label, font=font, fill='black')
            rotated_img = txt_img.rotate(90, expand=True)
            # The rotated text doubles as its own paste mask.
            img.paste(rotated_img, (int(start_x), int(text_y)), rotated_img)

    # Legend
    legend_x = margin
    legend_y = height - margin
    draw.text((legend_x, legend_y - 60), "SHAP Values:", fill='black', font=font)

    box_width, box_height = 20, 20
    spacing = 15
    # (x offset, distance above legend_y, swatch colour, caption)
    legend_entries = [
        (0, 45, (255, 0, 0, 255), "Strong human-like signal"),
        (0, 20, (255, 200, 200, 255), "Weak human-like signal"),
        (250, 45, (200, 200, 255, 255), "Weak non-human-like signal"),
        (250, 20, (0, 0, 255, 255), "Strong non-human-like signal"),
    ]
    for x_offset, rise, swatch, caption in legend_entries:
        box_x = legend_x + x_offset
        box_y = legend_y - rise
        draw.rectangle([
            (box_x, box_y),
            (box_x + box_width, box_y + box_height)
        ], fill=swatch, outline=(0, 0, 0, 255))
        draw.text((box_x + box_width + spacing, box_y),
                  caption, fill='black', font=font)

    return img
959
+
960
 
961
 
962
  ###############################################################################