"""Gradio app: project ESM-1b per-residue PCA components onto a PDB structure.

For each requested principal component, the per-residue PCA scores of the
ESM-1b embedding are min-max scaled to 0-100 and written into the B-factor
column of a copy of the uploaded structure, one output PDB per component.
"""

import os
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from Bio.PDB import PDBIO, PDBParser
from Bio.PDB.Polypeptide import is_aa
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, EsmModel

MODEL_NAME = "facebook/esm1b_t33_650M_UR50S"

# Load ESM-1b model and tokenizer once at import time. Only the final hidden
# layer is used below, so per-layer hidden states are not requested.
model = EsmModel.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def _clean_sequence(seq):
    """Return *seq* with all whitespace removed and letters upper-cased."""
    return "".join((seq or "").split()).upper()


def compute_scaled_pca_scores(seq, components):
    """Return per-residue PCA scores, scaled to 0-100, for each requested component.

    Args:
        seq: protein sequence in one-letter code. Whitespace (including
            newlines from pasted FASTA bodies) is ignored and letters are
            upper-cased before tokenisation.
        components: zero-based principal-component indices.

    Returns:
        A list with one 1-D array of length ``len(cleaned seq)`` per entry of
        *components*, in the same order.

    Raises:
        ValueError: if the sequence or component list is empty, the tokenizer
            does not yield exactly one token per residue, or a component
            index exceeds what PCA can extract for this sequence.
    """
    seq = _clean_sequence(seq)
    if not seq:
        raise ValueError("Please input a protein sequence.")
    if not components:
        raise ValueError("Please input at least one PCA component.")

    inputs = tokenizer(seq, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    hidden = outputs.last_hidden_state[0]

    n_residues = len(seq)
    # Expect exactly CLS + one token per residue + EOS; any other count means
    # characters were not mapped one-to-one and residue indices would drift.
    if hidden.shape[0] != n_residues + 2:
        raise ValueError(
            f"Sequence of {n_residues} residues produced {hidden.shape[0]} "
            "tokens; please use standard one-letter amino-acid codes."
        )
    embedding = hidden[1:n_residues + 1].cpu().numpy()  # drop CLS and EOS

    # PCA can extract at most min(n_samples, n_features) components.
    max_available = min(embedding.shape)
    highest = max(components)
    if highest >= max_available:
        raise ValueError(
            f"PC{highest} requested, but only PC0-PC{max_available - 1} "
            "can be computed for this sequence."
        )

    pca_result = PCA(n_components=highest + 1).fit_transform(embedding)

    scaled_components = []
    for c in components:
        selected = pca_result[:, c]
        lowest = selected.min()
        span = selected.max() - lowest
        if span > 0:
            scaled = (selected - lowest) / span * 100
        else:
            # A constant component has no range; map it to 0 rather than
            # dividing by zero and writing NaN into the B-factor column.
            scaled = np.zeros_like(selected)
        scaled_components.append(scaled)
    return scaled_components


def _pdb_path(pdb_file):
    """Return a filesystem path for a Gradio file value.

    Older Gradio versions pass a tempfile-like object exposing ``.name``;
    newer versions pass the path string itself.
    """
    return getattr(pdb_file, "name", pdb_file)


def inject_bfactors_and_save(pdb_file, scores_list, component_indices):
    """Write each score vector into the B-factor column and save one PDB per component.

    Scores are assigned in order to amino-acid residues (standard or
    modified, per ``is_aa(..., standard=False)``) across all chains of each
    model. Waters and other non-amino-acid hetero residues are skipped so
    they do not consume scores. The residue counter restarts for every model
    of a multi-model file. Residues beyond the end of a score vector keep
    their original B-factors.

    Args:
        pdb_file: Gradio file value (path string or object with ``.name``).
        scores_list: one per-residue score sequence per component.
        component_indices: component index for each entry of *scores_list*,
            used only to label the output file names.

    Returns:
        List of paths to the written PDB files, in the order of *scores_list*.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure("prot", _pdb_path(pdb_file))

    output_paths = []
    for scores, idx in zip(scores_list, component_indices):
        for pdb_model in structure:
            amino_acids = (
                residue
                for chain in pdb_model
                for residue in chain
                if is_aa(residue, standard=False)
            )
            for score, residue in zip(scores, amino_acids):
                for atom in residue:
                    atom.set_bfactor(float(score))

        # Close the temp-file handle immediately; PDBIO reopens the path.
        with tempfile.NamedTemporaryFile(delete=False, suffix=f"_PC{idx}.pdb") as tmp:
            out_path = tmp.name
        writer = PDBIO()
        writer.set_structure(structure)
        writer.save(out_path)
        output_paths.append(out_path)
    return output_paths


def _parse_components(component_string):
    """Parse e.g. ``"0, 1,2"`` into ``[0, 1, 2]``.

    Raises:
        ValueError: if no component is given or any entry is not a
            non-negative integer.
    """
    message = "Please input a comma-separated list of non-negative integers (e.g. 0,1,2)."
    tokens = [token.strip() for token in (component_string or "").split(",")]
    tokens = [token for token in tokens if token]
    try:
        components = [int(token) for token in tokens]
    except ValueError:
        raise ValueError(message) from None
    if not components or min(components) < 0:
        raise ValueError(message)
    return components


# Gradio interface logic
def process(seq, pdb_file, component_string):
    """Gradio callback: return paths of PDB files with PCA scores as B-factors.

    Input problems are reported to the user via ``gr.Error``. This keeps the
    return value a plain list of file paths, which is what the single
    ``gr.File`` output expects.
    """
    if pdb_file is None:
        raise gr.Error("Please upload a PDB file.")
    try:
        components = _parse_components(component_string)
        scores_list = compute_scaled_pca_scores(seq, components)
    except ValueError as err:
        raise gr.Error(str(err)) from err
    return inject_bfactors_and_save(pdb_file, scores_list, components)


# Gradio UI
demo = gr.Interface(
    fn=process,
    inputs=[
        gr.Textbox(label="Input Protein Sequence (1-letter code)"),
        gr.File(label="Upload PDB File", file_types=[".pdb"]),
        gr.Textbox(label="Comma-separated PCA Components (e.g. 0,1,2)"),
    ],
    outputs=gr.File(
        label="Download PDBs with PCA Projections",
        file_types=[".pdb"],
        file_count="multiple",
    ),
    title="ESM-1b PCA Component Projection: Multi-PC Structural Mapping",
)

if __name__ == "__main__":
    demo.launch()