#-------------------------------------------------------libraries------------------------------------------------------------------------------------
import torch
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, EsmModel
from sklearn.decomposition import PCA
#----------------------------------------------------Analysis------------------------------------------------------------------------------------
#--load model and tokenizer
model = EsmModel.from_pretrained("facebook/esm1b_t33_650M_UR50S", output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained("facebook/esm1b_t33_650M_UR50S")
#--task to execute
def extract_and_plot(seq, layer=-1):
    #--preprocess sequence (the tokenizer adds special tokens, which are included in the plot)
    inputs = tokenizer(seq, return_tensors="pt")
    #--forward pass
    with torch.no_grad():
        outputs = model(**inputs)
    hidden_states = outputs.hidden_states  #--> tuple: (layer0, ..., layer_final)
    #--select hidden state from specified layer
    layer = int(layer)  #--the Gradio slider may pass a float
    if layer == -1:
        embedding = hidden_states[-1][0]  #--> (seq_len, hidden_dim)
    else:
        embedding = hidden_states[layer][0]
    #--PCA
    pca = PCA(n_components=2)
    coords = pca.fit_transform(embedding.numpy())
    #--plot
    fig = plt.figure(figsize=(6, 4))
    plt.scatter(coords[:, 0], coords[:, 1])
    plt.title(f"PCA of ESM-1b embeddings (layer {layer})")
    plt.xlabel("PCA1")
    plt.ylabel("PCA2")
    plt.tight_layout()
    return fig
demo = gr.Interface(
    fn=extract_and_plot,
    inputs=[
        gr.Textbox(label="Protein Sequence"),
        gr.Slider(minimum=0, maximum=33, step=1, value=33, label="Layer (0-33; 33 = final)")
    ],
    outputs=gr.Plot()
)
demo.launch()
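#----------------------------------------------------Usage note------------------------------------------------------------------------------------
#--a minimal way to smoke-test the embedding/PCA step without launching the UI; the sample
#--sequence below is a hypothetical test input, not part of the original Space:
#--  fig = extract_and_plot("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQAPILSRVGDGTQDNLSGAEKAVQ", layer=33)
#--  fig.savefig("esm1b_layer33_pca.png")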