Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
import gradio as gr | |
import matplotlib.pyplot as plt | |
from transformers import AutoTokenizer, EsmModel | |
from sklearn.decomposition import PCA | |
from Bio.PDB import PDBParser, PDBIO | |
import tempfile | |
import os | |
# Load the ESM-1b model and tokenizer once, at import time.
ESM_MODEL_NAME = "facebook/esm1b_t33_650M_UR50S"
# Only `last_hidden_state` is consumed downstream, so the per-layer hidden
# states (output_hidden_states=True) are not requested: that would return
# every layer's activations on each forward pass for nothing.
model = EsmModel.from_pretrained(ESM_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(ESM_MODEL_NAME)
# Inference only; make evaluation mode explicit.
model.eval()
# Compute PCA and return scaled values for selected components
def compute_scaled_pca_scores(seq, components):
    """Project per-residue ESM-1b embeddings onto PCA components scaled to 0-100.

    The sequence is embedded with the module-level ESM-1b ``model`` and
    ``tokenizer``. The special tokens are dropped, and a PCA is fitted on the
    remaining per-residue vectors of this single sequence.

    Args:
        seq: Protein sequence in one-letter code. Whitespace (e.g. line breaks
            from a pasted FASTA body) is removed and letters are upper-cased
            before tokenization.
        components: 0-based indices of the principal components to return.

    Returns:
        A list with one 1-D numpy array per entry of ``components`` (same
        order). Each array holds one value per residue, min-max scaled to
        [0, 100]. A component that is constant across residues is returned as
        all zeros instead of NaNs. An empty ``components`` yields ``[]``.
        Note that the sign of a principal component is arbitrary, so "high"
        and "low" ends may swap between runs or sequences.

    Raises:
        ValueError: if a component index is negative, the sequence is empty,
            or a requested component cannot be computed because PCA cannot
            return more components than min(n_residues, embedding_dim).
    """
    components = list(components)
    if not components:
        # Nothing requested: skip the expensive forward pass entirely.
        return []
    if min(components) < 0:
        raise ValueError("PCA component indices must be non-negative.")

    # Assumes the ESM tokenizer emits one token per upper-case residue letter;
    # normalise so stray whitespace or lower-case input cannot break the
    # residue/token alignment.
    seq = "".join(seq.split()).upper()
    if not seq:
        raise ValueError("The protein sequence is empty.")

    # NOTE(review): ESM-1b has a finite positional range (roughly 1022
    # residues); longer sequences are expected to fail here - confirm.
    inputs = tokenizer(seq, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    # Drop the leading CLS and trailing EOS tokens. Slicing on the token axis
    # keeps exactly the rows the model produced for residues.
    embedding = outputs.last_hidden_state[0, 1:-1].cpu().numpy()

    n_components = max(components) + 1
    limit = min(embedding.shape)
    if n_components > limit:
        raise ValueError(
            f"Requested PC{n_components - 1}, but at most {limit} components "
            f"can be computed for a sequence of {embedding.shape[0]} residues."
        )
    pca_result = PCA(n_components=n_components).fit_transform(embedding)

    scaled_components = []
    for c in components:
        selected = pca_result[:, c]
        low = selected.min()
        span = selected.max() - low
        if span == 0:
            # A flat component has no contrast; avoid 0/0 -> NaN B-factors.
            scaled_components.append(np.zeros_like(selected))
        else:
            scaled_components.append((selected - low) / span * 100)
    return scaled_components
# Inject scores into B-factor column and save each PDB separately
def _assign_residue_bfactors(structure, scores):
    """Give every atom of the n-th residue (per model, file order) the n-th score."""
    for pdb_model in structure:
        # Restart the mapping for each model so every model of a multi-model
        # file (e.g. an NMR ensemble) is coloured the same way.
        # Waters (hetero flag "W") are never part of the sequence, so they
        # must not consume scores.
        # NOTE(review): other hetero residues still consume scores, in file
        # order - confirm this matches the intended sequence.
        residues = (
            residue
            for chain in pdb_model
            for residue in chain
            if residue.id[0] != "W"
        )
        # zip stops at the shorter side: extra residues keep their B-factors.
        for score, residue in zip(scores, residues):
            for atom in residue:
                atom.set_bfactor(float(score))


def _new_temp_path(suffix):
    """Create an empty temporary file that persists after closing; return its path."""
    # Close the handle immediately. Keeping it open leaks a descriptor, and on
    # Windows an open NamedTemporaryFile cannot be reopened for writing.
    with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
        return tmp.name


def inject_bfactors_and_save(pdb_file, scores_list, component_indices):
    """Write one PDB copy per component with its scores in the B-factor column.

    Args:
        pdb_file: The uploaded structure. Either an object whose ``.name`` is
            the file path (older Gradio file wrappers) or a plain path string.
        scores_list: One per-residue score sequence per component, in the same
            order as ``component_indices``.
        component_indices: Component numbers. They are used only to tag the
            output file names (``..._PC<idx>.pdb``).

    Returns:
        List of paths to the written PDB files, one per (scores, index) pair.

    Scores are mapped onto residues in file order within each model, across
    all chains, skipping waters. Every atom of a residue receives that
    residue's score. Residues beyond the end of a score sequence keep their
    original B-factors.
    """
    pdb_path = getattr(pdb_file, "name", pdb_file)
    structure = PDBParser(QUIET=True).get_structure("prot", pdb_path)
    output_paths = []
    for scores, idx in zip(scores_list, component_indices):
        # The same structure object is reused and overwritten per component.
        _assign_residue_bfactors(structure, scores)
        out_path = _new_temp_path(suffix=f"_PC{idx}.pdb")
        writer = PDBIO()
        writer.set_structure(structure)
        writer.save(out_path)
        output_paths.append(out_path)
    return output_paths
# Gradio interface logic
def _parse_components(component_string):
    """Parse e.g. "0, 1,2" into [0, 1, 2]; raise ValueError on bad or empty input."""
    message = (
        "Please input a comma-separated list of non-negative integers "
        "(e.g. 0,1,2)."
    )
    tokens = [tok.strip() for tok in (component_string or "").split(",")]
    tokens = [tok for tok in tokens if tok]  # tolerate "0,1," and extra spaces
    if not tokens:
        raise ValueError(message)
    components = []
    for tok in tokens:
        try:
            value = int(tok)
        except ValueError:
            raise ValueError(f"Invalid component {tok!r}. {message}") from None
        if value < 0:
            raise ValueError(f"Invalid component {tok!r}. {message}")
        components.append(value)
    return components


def process(seq, pdb_file, component_string):
    """Gradio callback: map the requested PCA components onto the uploaded structure.

    Args:
        seq: Protein sequence (one-letter code) from the sequence textbox.
        pdb_file: The uploaded PDB file, as delivered by ``gr.File``.
        component_string: Comma-separated, 0-based PCA component indices.

    Returns:
        List of paths to the generated PDB files, one per requested component.
        This matches the single multi-file output of the interface.

    Raises:
        gr.Error: on missing or malformed input. Gradio shows the message to
            the user instead of a stack trace.
    """
    if not seq or not seq.strip():
        raise gr.Error("Please input a protein sequence.")
    if pdb_file is None:
        raise gr.Error("Please upload a PDB file.")
    try:
        components = _parse_components(component_string)
        scores_list = compute_scaled_pca_scores(seq, components)
    except ValueError as err:
        # Surface validation problems (bad indices, too many components for
        # the sequence length) as a readable UI error.
        raise gr.Error(str(err)) from err
    return inject_bfactors_and_save(pdb_file, scores_list, components)
# Gradio UI
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="Input Protein Sequence (1-letter code)"),
        gr.File(label="Upload PDB File", file_types=[".pdb"]),
        gr.Textbox(label="Comma-separated PCA Components (e.g. 0,1,2)"),
    ],
    outputs=gr.File(
        label="Download PDBs with PCA Projections",
        file_types=[".pdb"],
        file_count="multiple",
    ),
    title="ESM-1b PCA Component Projection: Multi-PC Structural Mapping",
    description=(
        "Embeds the sequence with ESM-1b, fits a PCA over its per-residue "
        "embeddings, and writes each requested principal component "
        "(0-based) into the B-factor column of the uploaded structure, "
        "min-max scaled to 0-100. Residues are matched to the structure in "
        "file order, so the sequence should correspond to the modelled "
        "residues."
    ),
)

# Start the server only when run as a script. Importing the module (tests,
# tooling, reload mode) then does not block on launch().
if __name__ == "__main__":
    demo.launch()