dsk129 committed on
Commit
09c54b0
·
verified ·
1 Parent(s): 4ac36ee

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #-------------------------------------------------------libraries------------------------------------------------------------------------------------
2
+ import torch
3
+ import gradio as gr
4
+ import matplotlib.pyplot as plt
5
+ from transformers import AutoTokenizer, EsmModel
6
+ from sklearn.decomposition import PCA
7
+
8
+ #----------------------------------------------------Analysis------------------------------------------------------------------------------------
9
#--load tokenizer and model once at import time; both come from the same checkpoint
ESM_CHECKPOINT = "facebook/esm1b_t33_650M_UR50S"
tokenizer = AutoTokenizer.from_pretrained(ESM_CHECKPOINT)
#--output_hidden_states=True makes every forward pass return the per-layer hidden states
model = EsmModel.from_pretrained(ESM_CHECKPOINT, output_hidden_states=True)
12
+
13
+ #--task to execute
14
def extract_and_plot(seq, layer=-1):
    """Plot a 2-D PCA projection of the ESM-1b hidden states of a sequence.

    Args:
        seq: Protein sequence as a string.
        layer: Index into the model's ``hidden_states`` tuple. Negative
            values count from the end, so ``-1`` selects the final layer.
            Numeric values coming from a UI slider (e.g. ``33.0``) are
            coerced to ``int``.

    Returns:
        The matplotlib ``Figure`` containing the scatter plot, with one
        point per row of the selected hidden state (presumably one per
        token, including any special tokens the tokenizer adds).

    Raises:
        gr.Error: If ``seq`` is empty or contains only whitespace.
        IndexError: If ``layer`` is outside the available hidden states.
    """
    #--validate input
    seq = (seq or "").strip()
    if not seq:
        raise gr.Error("Please enter a protein sequence.")
    #--a slider may deliver a float, which cannot be used as a tuple index
    layer = int(layer)

    #--preprocess sequence
    inputs = tokenizer(seq, return_tensors="pt")

    #--forward pass (inference only, no gradients needed)
    with torch.no_grad():
        outputs = model(**inputs)
    hidden_states = outputs.hidden_states  #--> tuple: (layer0, ..., layer_final)

    #--select hidden state from specified layer; plain indexing already
    #--handles -1 (final layer), so no special case is required
    embedding = hidden_states[layer][0]  #--> (seq_len, hidden_dim)

    #--PCA
    coords = PCA(n_components=2).fit_transform(embedding.cpu().numpy())

    #--plot on an explicit figure instead of the global pyplot state, so
    #--concurrent requests cannot draw into each other's axes
    fig, ax = plt.subplots(figsize=(6, 4))
    ax.scatter(coords[:, 0], coords[:, 1])
    ax.set_title(f"PCA of esm1b embeddings (layer {layer})")
    ax.set_xlabel("PCA1")
    ax.set_ylabel("PCA2")
    fig.tight_layout()

    #--unregister the figure from pyplot so it is freed after rendering;
    #--the returned Figure object itself stays usable
    plt.close(fig)

    return fig
42
+
43
#--build the web UI: a sequence textbox and a layer slider feeding extract_and_plot
demo = gr.Interface(
    fn=extract_and_plot,
    inputs=[
        gr.Textbox(label="Protein Sequence"),
        #--the slider covers indices 0..33 of the hidden_states tuple, so -1
        #--cannot be selected here; the label describes the real range
        gr.Slider(
            minimum=0,
            maximum=33,
            step=1,
            value=33,
            label="Layer (0 = embeddings, 33 = final)",
        ),
    ],
    outputs=gr.Plot(),
)

#--start the Gradio server
demo.launch()