esm_embeddings / app.py
dsk129's picture
Update app.py
aa3c5fe verified
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, EsmModel
from sklearn.decomposition import PCA
from Bio.PDB import PDBParser, PDBIO
import tempfile
import os
# Load the ESM-1b model and tokenizer once at import time; every request reuses them.
ESM_CHECKPOINT = "facebook/esm1b_t33_650M_UR50S"
# Only `last_hidden_state` is consumed downstream, so per-layer hidden states
# are not requested (they would make every forward pass return every layer's
# activations).
model = EsmModel.from_pretrained(ESM_CHECKPOINT)
tokenizer = AutoTokenizer.from_pretrained(ESM_CHECKPOINT)
# Compute PCA and return scaled values for selected components
def compute_scaled_pca_scores(seq, components):
    """Project per-residue ESM embeddings onto selected principal components.

    The sequence is run through the module-level ESM model. The special tokens
    at both ends of the tokenized input are dropped, a PCA is fitted on the
    remaining per-residue embeddings, and each requested component is min-max
    scaled to the range 0-100.

    Args:
        seq: Protein sequence in one-letter code. Whitespace is removed and
            letters are upper-cased before tokenization.
        components: Zero-based indices of the principal components to return.

    Returns:
        A list with one 1-D numpy array per entry of ``components``, in the
        same order. Each array holds one score per residue token. A component
        with zero range is returned as all zeros. An empty ``components``
        yields an empty list.

    Raises:
        ValueError: if a component index is negative, or if it is not smaller
            than the number of principal components that can be computed for
            this sequence.
    """
    if not components:
        return []
    if min(components) < 0:
        raise ValueError("PCA component indices must be non-negative.")
    # ESM tokens are upper-case one-letter residue codes. Normalize case and
    # drop whitespace so that every character maps to exactly one residue token.
    seq = "".join(seq.split()).upper()
    # NOTE(review): ESM-1b has a fixed maximum input length; long sequences are
    # not truncated here — confirm the desired behavior for those.
    inputs = tokenizer(seq, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Drop the leading CLS and trailing EOS tokens. Slicing by token position
    # keeps exactly one row per residue token, independent of len(seq).
    embedding = outputs.last_hidden_state[0, 1:-1].cpu().numpy()

    highest = max(components)
    # PCA can produce at most min(n_residues, embedding_dim) components.
    n_available = min(embedding.shape)
    if highest >= n_available:
        raise ValueError(
            f"Requested PC{highest}, but only {n_available} principal components "
            f"can be computed for a sequence of {embedding.shape[0]} residues."
        )
    pca_result = PCA(n_components=highest + 1).fit_transform(embedding)

    scaled_components = []
    for c in components:
        selected = pca_result[:, c]
        lowest = selected.min()
        span = selected.max() - lowest
        # A constant component (e.g. a single residue) has zero range; map it
        # to 0 instead of dividing by zero and emitting NaN B-factors.
        if span > 0:
            scaled = (selected - lowest) / span * 100
        else:
            scaled = np.zeros_like(selected)
        scaled_components.append(scaled)
    return scaled_components
# Inject scores into B-factor column and save each PDB separately
def inject_bfactors_and_save(pdb_file, scores_list, component_indices):
    """Write one copy of the structure per component, with scores as B-factors.

    Within each model, amino-acid residues are visited in file order, chain by
    chain. The i-th residue receives ``scores[i]`` on every one of its atoms.
    Non-amino-acid residues (waters, ligands) are skipped so that they do not
    shift the alignment. Each model of a multi-model file is mapped
    independently, starting from the first score. Residues beyond the end of
    ``scores`` keep whatever B-factor they already had: either the value from
    the file or one set for an earlier component in this call.

    NOTE(review): this assumes the amino-acid residues present in the file
    correspond one-to-one, in order, to the scored sequence. Residues missing
    from the file (unresolved density) will shift the mapping.

    Args:
        pdb_file: The uploaded PDB file, given either as a path string or as
            an object whose ``.name`` attribute is the path.
        scores_list: One per-residue score sequence per component.
        component_indices: Component numbers. They are used only to name the
            output files.

    Returns:
        A list of paths to the written temporary ``.pdb`` files, one per pair
        in ``zip(scores_list, component_indices)``.
    """
    from Bio.PDB.Polypeptide import is_aa

    # Accept a plain path as well as a file wrapper exposing `.name`.
    pdb_path = getattr(pdb_file, "name", pdb_file)
    structure = PDBParser(QUIET=True).get_structure("prot", pdb_path)
    writer = PDBIO()
    # The structure is mutated in place below, so one set_structure suffices.
    writer.set_structure(structure)
    output_paths = []
    for scores, idx in zip(scores_list, component_indices):
        for pdb_model in structure:
            amino_acids = (
                residue
                for chain in pdb_model
                for residue in chain
                if is_aa(residue, standard=False)
            )
            for residue, score in zip(amino_acids, scores):
                for atom in residue:
                    atom.bfactor = float(score)
        # Write through the handle so the temp file is closed deterministically;
        # delete=False keeps it on disk for Gradio to serve.
        with tempfile.NamedTemporaryFile(
            "w", delete=False, suffix=f"_PC{idx}.pdb"
        ) as handle:
            writer.save(handle)
        output_paths.append(handle.name)
    return output_paths
# Gradio interface logic
def process(seq, pdb_file, component_string):
    """Gradio callback that maps the chosen PCs of the sequence onto the structure.

    Args:
        seq: Protein sequence entered by the user.
        pdb_file: Value of the ``gr.File`` input; None when nothing was uploaded.
        component_string: Comma-separated, zero-based PC indices, e.g. ``"0,1,2"``.

    Returns:
        Paths of the PDB files written by ``inject_bfactors_and_save``.

    Raises:
        gr.Error: on missing or malformed input, or when the requested
            components cannot be computed. Gradio shows the message in the UI.
    """
    if not seq or not seq.strip():
        raise gr.Error("Error: Please input a protein sequence.")
    if pdb_file is None:
        raise gr.Error("Error: Please upload a PDB file.")

    tokens = [part.strip() for part in (component_string or "").split(",")]
    try:
        components = [int(token) for token in tokens if token]
    except ValueError:
        raise gr.Error(
            "Error: Please input a comma-separated list of integers."
        ) from None
    if not components:
        raise gr.Error("Error: Please input a comma-separated list of integers.")
    if min(components) < 0:
        raise gr.Error("Error: PCA component indices must be non-negative.")

    try:
        scores_list = compute_scaled_pca_scores(seq, components)
    except ValueError as err:
        raise gr.Error(f"Error: {err}") from err
    return inject_bfactors_and_save(pdb_file, scores_list, components)
# Gradio UI: sequence + structure + component list in, one annotated PDB per component out.
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="Input Protein Sequence (1-letter code)"),
        gr.File(label="Upload PDB File", file_types=[".pdb"]),
        gr.Textbox(label="Comma-separated PCA Components (e.g. 0,1,2)"),
    ],
    outputs=gr.File(
        label="Download PDBs with PCA Projections",
        file_types=[".pdb"],
        file_count="multiple",
    ),
    title="ESM-1b PCA Component Projection: Multi-PC Structural Mapping",
)

# Launch only when run as a script, so importing this module does not start a server.
if __name__ == "__main__":
    demo.launch()