esm_embeddings / app.py
dsk129's picture
Update app.py
14b630e verified
raw
history blame
2.52 kB
#-------------------------------------------------libraries---------------------------------------------------------------------------------------------------------------
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, EsmModel
from sklearn.metrics.pairwise import cosine_similarity
from Bio.PDB import PDBParser, PDBIO
from Bio.PDB.StructureBuilder import StructureBuilder
import tempfile
import os
#----------------------------------------------------Analysis--------------------------------------------------------------------------------------------------------
# Load the ESM-1b model and its tokenizer once, at import time, from the same
# pretrained checkpoint so every request reuses the same objects.
ESM_CHECKPOINT = "facebook/esm1b_t33_650M_UR50S"
model = EsmModel.from_pretrained(ESM_CHECKPOINT, output_hidden_states=True)
tokenizer = AutoTokenizer.from_pretrained(ESM_CHECKPOINT)
def compute_residue_scores(seq):
    """Score each residue by its summed embedding similarity to all residues.

    The sequence is embedded with the module-level ``tokenizer`` and
    ``model``. The per-residue vectors are compared pairwise with cosine
    similarity, each row of the similarity matrix is summed, and the sums
    are min-max scaled to the range 0-100.

    Args:
        seq: Protein sequence in one-letter code. Whitespace (spaces, tabs,
            line breaks from pasted text) is removed before scoring.

    Returns:
        A 1-D NumPy array of scores in [0, 100], one per residue. When all
        raw scores are equal (for example a single-residue sequence) every
        score is 0 instead of NaN.

    Raises:
        ValueError: If ``seq`` is empty or contains only whitespace.
    """
    # Pasted sequences often carry line breaks; strip whitespace so the
    # length used for slicing below counts residues only.
    cleaned = "".join((seq or "").split())
    if not cleaned:
        raise ValueError("Protein sequence is empty.")

    inputs = tokenizer(cleaned, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # The hidden state has two extra positions around the residues
    # (shape (L+2, d)); keep only the L residue rows.
    num_residues = len(cleaned)
    embedding = outputs.last_hidden_state[0][1:num_residues + 1]

    sim_matrix = cosine_similarity(embedding.detach().cpu().numpy())
    residue_scores = np.sum(sim_matrix, axis=1)

    lowest = np.min(residue_scores)
    spread = np.max(residue_scores) - lowest
    if spread == 0:
        # Min-max scaling is undefined for constant scores; avoid 0/0 -> NaN.
        return np.zeros_like(residue_scores)
    return 100 * (residue_scores - lowest) / spread
def inject_bfactors_into_pdb(pdb_file, scores):
    """Write a copy of a PDB structure with per-residue scores as B-factors.

    Residues are visited in the order Biopython iterates them (models, then
    chains, then residues) and paired one-to-one with ``scores``; every atom
    of a residue receives that residue's score. Residues beyond the end of
    ``scores`` keep their original B-factors, and surplus scores are ignored.

    Args:
        pdb_file: The uploaded structure, given either as a filesystem path
            (``str`` / ``os.PathLike``) or as an object whose ``name``
            attribute holds the path.
        scores: Iterable of numeric per-residue scores.

    Returns:
        Path of a newly created temporary ``.pdb`` file. The file is not
        removed automatically.

    Raises:
        ValueError: If ``pdb_file`` is ``None`` (no file was uploaded).
    """
    if pdb_file is None:
        raise ValueError("No PDB file was provided.")

    # Accept a plain path as well as a file wrapper exposing ``.name``.
    # Path-like objects are checked first because ``pathlib.Path.name`` is
    # only the final path component, not the full path.
    if isinstance(pdb_file, (str, os.PathLike)):
        pdb_path = os.fspath(pdb_file)
    else:
        pdb_path = pdb_file.name

    structure = PDBParser(QUIET=True).get_structure("prot", pdb_path)

    # zip stops at the shorter of the two, matching the one-score-per-residue
    # assignment without a manual counter.
    for residue, score in zip(structure.get_residues(), scores):
        bfactor = float(score)
        for atom in residue:
            atom.bfactor = bfactor

    # Only the reserved path is needed; close the handle immediately so it is
    # not leaked and the file can be reopened for writing on every platform.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".pdb") as tmp:
        out_path = tmp.name

    writer = PDBIO()
    writer.set_structure(structure)
    writer.save(out_path)
    return out_path
def process(seq, pdb_file):
    """Gradio callback: score the sequence and write the scores into the PDB.

    Args:
        seq: Protein sequence forwarded to ``compute_residue_scores``.
        pdb_file: Uploaded PDB file forwarded to ``inject_bfactors_into_pdb``.

    Returns:
        The value returned by ``inject_bfactors_into_pdb`` for the uploaded
        structure and the computed scores.
    """
    return inject_bfactors_into_pdb(pdb_file, compute_residue_scores(seq))
# Gradio interface: a sequence textbox and a PDB upload go into ``process``;
# the result is offered back as a downloadable file.
sequence_input = gr.Textbox(label="Input Protein Sequence (1-letter code)")
pdb_input = gr.File(label="Upload PDB File", file_types=[".pdb"])
scored_pdb_output = gr.File(label="Modified PDB with Scores in B-factor Column")

demo = gr.Interface(
    fn=process,
    inputs=[sequence_input, pdb_input],
    outputs=scored_pdb_output,
    title="ESM-1b Residue Scoring: B-factor Injection for Structural Visualization",
)

# Start the web server as soon as the module is executed.
demo.launch()