dsk129 commited on
Commit
4c8a467
·
verified ·
1 Parent(s): 041b838

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -19
app.py CHANGED
@@ -63,53 +63,63 @@ demo = gr.Interface(
63
  demo.launch()
64
  '''
65
 
66
- #--hydrophobic classification
67
- #--define hydrophobicity classification
 
 
 
 
 
 
 
 
 
 
68
  nonpolar = set("AFLIVMYW")
69
  polar = set("QERSDHKNT")
70
 
71
  def classify_residues(seq):
72
- return["nonpolar" if aa in nonpolar else "polar" if aa in polar else "other" for aa in seq]
73
 
74
  def compute_cosine_heatmap(seq):
75
- #--tokenize
76
  inputs = tokenizer(seq, return_tensors="pt")
77
  with torch.no_grad():
78
  outputs = model(**inputs)
79
- embedding = outputs.last_hidden_state[0]
80
 
81
- #--remove [CLS] and [EOS] if present
82
  L = len(seq)
83
  embedding = embedding[1:L+1]
84
 
85
- #--cosine similarity matrix
86
  sim_matrix = cosine_similarity(embedding.detach().cpu().numpy())
87
 
88
- #--residue classification
89
  residue_classes = classify_residues(seq)
90
  class_colors = {
91
- "nonpolor": "magenta",
92
- "polar": "indigo",
93
- "other": "steelblue"
94
  }
95
  row_colors = [class_colors[c] for c in residue_classes]
96
 
97
- #--plot heatmap
98
  fig, ax = plt.subplots(figsize=(8, 6))
99
  im = ax.imshow(sim_matrix, cmap="viridis")
100
- fig.colors(im, ax=ax, fraction=0.046, pad=0.04)
101
- ax.set_title("residue-residue cosine similarity")
102
- ax.set_xlabel("residue index")
103
- ax.set_ylabel("residue index")
104
 
105
- #--add colored ticks for class annotation
106
  for spine in ax.spines.values():
107
  spine.set_visible(False)
108
  ax.set_xticks(range(L))
109
  ax.set_yticks(range(L))
110
  ax.tick_params(length=0)
111
 
112
- #--color-code labels
113
  ax.set_xticklabels(residue_classes, rotation=90, fontsize=6)
114
  ax.set_yticklabels(residue_classes, fontsize=6)
115
  for label, color in zip(ax.get_xticklabels(), row_colors):
@@ -120,7 +130,7 @@ def compute_cosine_heatmap(seq):
120
  fig.tight_layout()
121
  return fig
122
 
123
- #--Gradio UI
124
  demo = gr.Interface(
125
  fn=compute_cosine_heatmap,
126
  inputs=gr.Textbox(label="Input Protein Sequence (1-letter code)"),
 
63
  demo.launch()
64
  '''
65
 
66
+ import torch
67
+ import gradio as gr
68
+ import matplotlib.pyplot as plt
69
+ import numpy as np
70
+ from sklearn.metrics.pairwise import cosine_similarity
71
+ from transformers import AutoTokenizer, EsmModel
72
+
73
+ # Load model
74
+ model = EsmModel.from_pretrained("facebook/esm1b_t33_650M_UR50S", output_hidden_states=True)
75
+ tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
76
+
77
+ # Define hydrophobicity classification
78
  nonpolar = set("AFLIVMYW")
79
  polar = set("QERSDHKNT")
80
 
81
  def classify_residues(seq):
82
+ return ["nonpolar" if aa in nonpolar else "polar" if aa in polar else "other" for aa in seq]
83
 
84
  def compute_cosine_heatmap(seq):
85
+ # Tokenize
86
  inputs = tokenizer(seq, return_tensors="pt")
87
  with torch.no_grad():
88
  outputs = model(**inputs)
89
+ embedding = outputs.last_hidden_state[0] # shape (L, 1280)
90
 
91
+ # Remove [CLS] and [EOS] if present
92
  L = len(seq)
93
  embedding = embedding[1:L+1]
94
 
95
+ # Cosine similarity matrix
96
  sim_matrix = cosine_similarity(embedding.detach().cpu().numpy())
97
 
98
+ # Residue classification
99
  residue_classes = classify_residues(seq)
100
  class_colors = {
101
+ "nonpolar": "magenta",
102
+ "polar": "indigo",
103
+ "other": "steelblue"
104
  }
105
  row_colors = [class_colors[c] for c in residue_classes]
106
 
107
+ # Plot heatmap
108
  fig, ax = plt.subplots(figsize=(8, 6))
109
  im = ax.imshow(sim_matrix, cmap="viridis")
110
+ fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
111
+ ax.set_title("Residue–Residue Cosine Similarity")
112
+ ax.set_xlabel("Residue Index")
113
+ ax.set_ylabel("Residue Index")
114
 
115
+ # Add colored ticks for class annotation
116
  for spine in ax.spines.values():
117
  spine.set_visible(False)
118
  ax.set_xticks(range(L))
119
  ax.set_yticks(range(L))
120
  ax.tick_params(length=0)
121
 
122
+ # Color-code labels
123
  ax.set_xticklabels(residue_classes, rotation=90, fontsize=6)
124
  ax.set_yticklabels(residue_classes, fontsize=6)
125
  for label, color in zip(ax.get_xticklabels(), row_colors):
 
130
  fig.tight_layout()
131
  return fig
132
 
133
+ # Gradio UI
134
  demo = gr.Interface(
135
  fn=compute_cosine_heatmap,
136
  inputs=gr.Textbox(label="Input Protein Sequence (1-letter code)"),