hiyata committed on
Commit
9cb16e9
·
verified ·
1 Parent(s): 9308c12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +64 -11
app.py CHANGED
@@ -224,6 +224,28 @@ def compute_gc_content(sequence):
224
  # 7. MAIN ANALYSIS STEP (Gradio Step 1)
225
  ###############################################################################
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
228
  if fasta_text.strip():
229
  text = fasta_text.strip()
@@ -232,13 +254,13 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
232
  with open(file_obj, 'r') as f:
233
  text = f.read()
234
  except Exception as e:
235
- return (f"Error reading file: {str(e)}", None, None, None, None, None)
236
  else:
237
- return ("Please provide a FASTA sequence.", None, None, None, None, None)
238
 
239
  sequences = parse_fasta(text)
240
  if not sequences:
241
- return ("No valid FASTA sequences found.", None, None, None, None, None)
242
  header, seq = sequences[0]
243
 
244
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
@@ -249,7 +271,7 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
249
  model.load_state_dict(state_dict)
250
  scaler = joblib.load('scaler.pkl')
251
  except Exception as e:
252
- return (f"Error loading model/scaler: {str(e)}", None, None, None, None, None)
253
 
254
  freq_vector = sequence_to_kmer_vector(seq)
255
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
@@ -284,11 +306,13 @@ def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
284
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
285
  heatmap_img = fig_to_image(heatmap_fig)
286
 
287
- # You might want to provide a CSV or other data for the 6th return item
288
- # Here, we'll simply return None for the file download:
 
 
289
  state_dict_out = {"seq": seq, "shap_means": shap_means}
290
 
291
- return (results_text, bar_img, heatmap_img, state_dict_out, header, None)
292
 
293
  ###############################################################################
294
  # 8. SUBREGION ANALYSIS (Gradio Step 2)
@@ -963,9 +987,22 @@ def prepare_csv_download(data, filename="analysis_results.csv"):
963
  return output.getvalue().encode(), filename
964
  else:
965
  raise ValueError("Unsupported data type for CSV download")
 
 
 
 
 
 
 
 
 
 
 
 
 
966
 
967
  ###############################################################################
968
- # 13. BUILD GRADIO INTERFACE
969
  ###############################################################################
970
 
971
  css = """
@@ -993,6 +1030,10 @@ with gr.Blocks(css=css) as iface:
993
  with gr.Column(scale=1):
994
  file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
995
  text_input = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
 
 
 
 
996
  top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
997
  win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window size for 'most pushing' subregions")
998
  analyze_btn = gr.Button("Analyze Sequence", variant="primary")
@@ -1000,14 +1041,25 @@ with gr.Blocks(css=css) as iface:
1000
  results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
1001
  kmer_img = gr.Image(label="Top k-mer SHAP")
1002
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
1003
- download_results = gr.File(label="Download Results", visible=False, elem_classes="download-button")
 
 
 
 
1004
  seq_state = gr.State()
1005
  header_state = gr.State()
1006
 
 
 
 
 
 
 
 
1007
  analyze_btn.click(
1008
  analyze_sequence,
1009
  inputs=[file_input, top_k, text_input, win_size],
1010
- outputs=[results_box, kmer_img, genome_img, seq_state, header_state, download_results]
1011
  )
1012
 
1013
  with gr.Tab("2) Subregion Exploration"):
@@ -1114,8 +1166,9 @@ with gr.Blocks(css=css) as iface:
1114
  - Statistical summary of differences
1115
  - **Data Export**:
1116
  - Download results as CSV files
 
1117
  - Save analysis outputs for further processing
1118
  """)
1119
 
1120
  if __name__ == "__main__":
1121
- iface.launch()
 
224
  # 7. MAIN ANALYSIS STEP (Gradio Step 1)
225
  ###############################################################################
226
 
227
def create_kmer_shap_csv(kmers, shap_values):
    """Write k-mer SHAP values to a temporary CSV file and return its path.

    The CSV has two columns, ``kmer`` and ``shap_value``, with rows ordered
    by decreasing absolute SHAP value (most influential first). Ties keep
    their input order.

    Args:
        kmers: Sequence of k-mer labels, one per SHAP value.
        shap_values: SHAP values aligned with ``kmers``. Any array-like is
            accepted; multi-dimensional input (e.g. shape ``(1, n)``) is
            flattened before use.

    Returns:
        Path of the newly created CSV file in the system temp directory.
        The file is not removed automatically.

    Raises:
        ValueError: If the number of k-mers and SHAP values differ.
    """
    labels = list(kmers)
    # Flatten so that a (1, n) row of SHAP values works like a flat vector.
    scores = np.asarray(shap_values).reshape(-1)
    if len(labels) != scores.size:
        raise ValueError(
            f"kmers and shap_values must have the same length "
            f"(got {len(labels)} and {scores.size})"
        )

    # Stable sort on negated magnitude: largest |SHAP| first, ties in input order.
    order = np.argsort(-np.abs(scores), kind="stable")
    kmer_df = pd.DataFrame({
        'kmer': [labels[i] for i in order],
        'shap_value': scores[order],
    })

    # mkstemp creates the file atomically with a unique name, avoiding the
    # collision/race window of a hand-built random filename.
    fd, temp_path = tempfile.mkstemp(prefix="kmer_shap_values_", suffix=".csv")
    with os.fdopen(fd, "w", newline="") as handle:
        kmer_df.to_csv(handle, index=False)

    return temp_path
248
+
249
  def analyze_sequence(file_obj, top_kmers=10, fasta_text="", window_size=500):
250
  if fasta_text.strip():
251
  text = fasta_text.strip()
 
254
  with open(file_obj, 'r') as f:
255
  text = f.read()
256
  except Exception as e:
257
+ return (f"Error reading file: {str(e)}", None, None, None, None, None, None)
258
  else:
259
+ return ("Please provide a FASTA sequence.", None, None, None, None, None, None)
260
 
261
  sequences = parse_fasta(text)
262
  if not sequences:
263
+ return ("No valid FASTA sequences found.", None, None, None, None, None, None)
264
  header, seq = sequences[0]
265
 
266
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
271
  model.load_state_dict(state_dict)
272
  scaler = joblib.load('scaler.pkl')
273
  except Exception as e:
274
+ return (f"Error loading model/scaler: {str(e)}", None, None, None, None, None, None)
275
 
276
  freq_vector = sequence_to_kmer_vector(seq)
277
  scaled_vector = scaler.transform(freq_vector.reshape(1, -1))
 
306
  heatmap_fig = plot_linear_heatmap(shap_means, title="Genome-wide SHAP")
307
  heatmap_img = fig_to_image(heatmap_fig)
308
 
309
+ # Create CSV with k-mer SHAP values and return the file path
310
+ kmer_shap_csv = create_kmer_shap_csv(kmers, shap_values)
311
+
312
+ # State dictionary for subregion analysis
313
  state_dict_out = {"seq": seq, "shap_means": shap_means}
314
 
315
+ return (results_text, bar_img, heatmap_img, state_dict_out, header, None, kmer_shap_csv)
316
 
317
  ###############################################################################
318
  # 8. SUBREGION ANALYSIS (Gradio Step 2)
 
987
  return output.getvalue().encode(), filename
988
  else:
989
  raise ValueError("Unsupported data type for CSV download")
990
+
991
+ ###############################################################################
992
+ # 13. EXAMPLE FASTA LOADER
993
+ ###############################################################################
994
+
995
def load_example_fasta():
    """Return the contents of ``example.fasta`` from the working directory.

    Returns:
        The file's text on success. If the file cannot be opened or decoded,
        a placeholder FASTA record followed by a note describing the error
        is returned instead, so the caller always receives a string.
    """
    try:
        # Explicit encoding keeps the result independent of the host locale.
        with open('example.fasta', 'r', encoding='utf-8') as f:
            return f.read()
    except (OSError, UnicodeError) as e:
        # Only I/O and decoding failures are expected here; anything else
        # is a programming error and should surface.
        return f">example_sequence\nACGTACGT...\n\n(Note: Could not load example.fasta: {str(e)})"
1003
 
1004
  ###############################################################################
1005
+ # 14. BUILD GRADIO INTERFACE
1006
  ###############################################################################
1007
 
1008
  css = """
 
1030
  with gr.Column(scale=1):
1031
  file_input = gr.File(label="Upload FASTA file", file_types=[".fasta", ".fa", ".txt"], type="filepath")
1032
  text_input = gr.Textbox(label="Or paste FASTA sequence", placeholder=">sequence_name\nACGTACGT...", lines=5)
1033
+
1034
+ with gr.Row():
1035
+ example_btn = gr.Button("Load Example FASTA", variant="secondary")
1036
+
1037
  top_k = gr.Slider(minimum=5, maximum=30, value=10, step=1, label="Number of top k-mers to display")
1038
  win_size = gr.Slider(minimum=100, maximum=5000, value=500, step=100, label="Window size for 'most pushing' subregions")
1039
  analyze_btn = gr.Button("Analyze Sequence", variant="primary")
 
1041
  results_box = gr.Textbox(label="Classification Results", lines=12, interactive=False)
1042
  kmer_img = gr.Image(label="Top k-mer SHAP")
1043
  genome_img = gr.Image(label="Genome-wide SHAP Heatmap (Blue=neg, White=0, Red=pos)")
1044
+
1045
+ with gr.Row():
1046
+ download_kmer_shap = gr.File(label="Download k-mer SHAP Values (CSV)", visible=True)
1047
+ download_results = gr.File(label="Download Results", visible=False, elem_classes="download-button")
1048
+
1049
  seq_state = gr.State()
1050
  header_state = gr.State()
1051
 
1052
+ # Event handlers
1053
+ example_btn.click(
1054
+ load_example_fasta,
1055
+ inputs=[],
1056
+ outputs=[text_input]
1057
+ )
1058
+
1059
  analyze_btn.click(
1060
  analyze_sequence,
1061
  inputs=[file_input, top_k, text_input, win_size],
1062
+ outputs=[results_box, kmer_img, genome_img, seq_state, header_state, download_results, download_kmer_shap]
1063
  )
1064
 
1065
  with gr.Tab("2) Subregion Exploration"):
 
1166
  - Statistical summary of differences
1167
  - **Data Export**:
1168
  - Download results as CSV files
1169
+ - Download k-mer SHAP values
1170
  - Save analysis outputs for further processing
1171
  """)
1172
 
1173
  if __name__ == "__main__":
1174
+ iface.launch()