dsk129 commited on
Commit
8f25d3c
·
verified ·
1 Parent(s): ee0eed8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +71 -3
app.py CHANGED
@@ -1,10 +1,11 @@
1
  #-------------------------------------------------------libraries------------------------------------------------------------------------------------
2
  import torch
3
- import numpy
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  from transformers import AutoTokenizer, EsmModel
7
  from sklearn.decomposition import PCA
 
8
 
9
  #----------------------------------------------------Analysis------------------------------------------------------------------------------------
10
  #--load model and tokenizer
@@ -19,8 +20,8 @@ print("Transformers version:", transformers.__version__)
19
  #import torch
20
  print("Torch NumPy test:", torch.ones(1).numpy())
21
 
22
-
23
- #--task to execute
24
  def extract_and_plot(seq, layer=-1):
25
  #--preprocess sequence
26
  inputs = tokenizer(seq, return_tensors="pt")
@@ -59,4 +60,71 @@ demo = gr.Interface(
59
  outputs=gr.Plot()
60
  )
61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  demo.launch()
 
1
  #-------------------------------------------------------libraries------------------------------------------------------------------------------------
2
  import torch
3
+ import numpy as np
4
  import gradio as gr
5
  import matplotlib.pyplot as plt
6
  from transformers import AutoTokenizer, EsmModel
7
  from sklearn.decomposition import PCA
8
+ from sklearn.metrics.pairwise import cosine_similarity
9
 
10
  #----------------------------------------------------Analysis------------------------------------------------------------------------------------
11
  #--load model and tokenizer
 
20
  #import torch
21
  print("Torch NumPy test:", torch.ones(1).numpy())
22
 
23
+ '''
24
+ #--principal component plot
25
  def extract_and_plot(seq, layer=-1):
26
  #--preprocess sequence
27
  inputs = tokenizer(seq, return_tensors="pt")
 
60
  outputs=gr.Plot()
61
  )
62
 
63
+ demo.launch()
64
+ '''
65
+
66
+ #--hydrophobic classification
67
+ #--define hydrophobicity classification
68
+ nonpolar = set("AFLIVMYW")
69
+ polar = set("QERSDHKNT")
70
+
71
+ def classify_residues(seq):
72
+ return["nonpolar" if aa in nonpolar else "polar" if aa in polar else "other" for aa in seq]
73
+
74
+ def compute_cosine_heatmap(seq):
75
+ #--tokenize
76
+ inputs = tokenizer(seq, reuturn_tenors="pt")
77
+ with torch.no_grad():
78
+ outputs = model(**inputs)
79
+ embedding = outputs.last_hidden_state[0]
80
+
81
+ #--remove [CLS] and [EOS] if present
82
+ L = len(seq)
83
+ embedding = embedding[1:L+1]
84
+
85
+ #--cosine similarity matrix
86
+ sim_matrix = cosine_similarity(embedding.detach().cpu().numpy())
87
+
88
+ #--residue classification
89
+ residue_classes = classify_residues(seq)
90
+ class_colors = {
91
+ "nonpolor": "magenta",
92
+ "polar": "indigo",
93
+ "other": "steelblue"
94
+ }
95
+ row_colors = [class_colors[c] for c in residues_classes]
96
+
97
+ #--plot heatmap
98
+ fig, ax = plt.subplots(figsize=(8, 6))
99
+ im = ax.imshow(sim_matrix, cmap="viridis")
100
+ fig.colors(im, ax=ax, fraction=0.046, pad=0.04)
101
+ ax.set_title("residue-residue cosine similarity")
102
+ ax.set_xlabel("residue index")
103
+ ax.set_ylabel("residue index")
104
+
105
+ #--add colored ticks for class annotation
106
+ for spine in ax.spines.values():
107
+ spine.set_visible(False)
108
+ ax.set_xticks(range(L))
109
+ ax.set_yticks(range(L))
110
+ ax.tick_params(length=0)
111
+
112
+ #--color-code labels
113
+ ax.set_xticklabels(residue_classes, rotation=90, fontsize=6)
114
+ ax.set_yticklabels(residue_classes, fontsize=6)
115
+ for label, color in zip(ax.get_xticklabels(), row_colors):
116
+ label.set_color(color)
117
+ for label, color in zip(ax.get_yticklabels(), row_colors):
118
+ label.set_color(color)
119
+
120
+ fig.tight_layout()
121
+ return fig
122
+
123
+ #--Gradio UI
124
+ demo = gr.Interface(
125
+ fn=compute_cosine_heatmap,
126
+ inputs=gr.Textbox(label="Input Protein Sequence (1-letter code)"),
127
+ outputs=gr.Plot()
128
+ )
129
+
130
  demo.launch()