"""Encode PDB structures into ESM3 structure-token sequences.

Run as a script with either ``--pdb_file`` (a single structure) or
``--pdb_dir`` (every ``.pdb`` file directly inside a directory).  One JSON
object per structure is written to ``--out_file``, one object per line.
"""

import argparse
import json
import os
import sys

# Make modules under the current working directory importable.  This has to
# run before the imports below in case any of them is resolved from there.
sys.path.append(os.getcwd())

import biotite.structure.io as bsio  # noqa: E402
import numpy as np  # noqa: E402
import torch  # noqa: E402
from biotite.structure.io.pdb import PDBFile  # noqa: E402
from esm.models.vqvae import StructureTokenEncoder  # noqa: E402
from esm.utils.structure.protein_chain import ProteinChain  # noqa: E402
from tqdm import tqdm  # noqa: E402

VQVAE_CODEBOOK_SIZE = 4096
# Special-token ids are numbered directly after the codebook entries.
VQVAE_SPECIAL_TOKENS = {
    "MASK": VQVAE_CODEBOOK_SIZE,
    "EOS": VQVAE_CODEBOOK_SIZE + 1,
    "BOS": VQVAE_CODEBOOK_SIZE + 2,
    "PAD": VQVAE_CODEBOOK_SIZE + 3,
    "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4,
}

# Default checkpoint location; relative to the current working directory.
DEFAULT_ENCODER_WEIGHTS = "./src/data/weight/esm3_structure_encoder_v0.pth"


def ESM3_structure_encoder_v0(
    device: torch.device | str = "cpu",
    weight_path: str = DEFAULT_ENCODER_WEIGHTS,
):
    """Build the structure-token encoder and load its weights.

    Args:
        device: Device the model is moved to and the checkpoint is mapped to.
        weight_path: Path of the state-dict checkpoint to load.  Defaults to
            the location this script has always used.

    Returns:
        A ``StructureTokenEncoder`` in eval mode with the checkpoint loaded.
    """
    model = (
        StructureTokenEncoder(
            d_model=1024,
            n_heads=1,
            v_heads=128,
            n_layers=2,
            d_out=128,
            n_codes=VQVAE_CODEBOOK_SIZE,
        )
        .to(device)
        .eval()
    )
    state_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(state_dict)
    return model


def get_esm3_structure_seq(pdb_file, encoder, device="cuda:0"):
    """Encode one chain of a PDB file into ESM3 structure tokens.

    Args:
        pdb_file: Path of the PDB file to encode.
        encoder: Object exposing ``encode(coords, residue_index=...)`` whose
            second return value holds the structure tokens (e.g. the model
            returned by ``ESM3_structure_encoder_v0``).
        device: Device the encoder inputs are moved to; it should match the
            device the encoder lives on.

    Returns:
        A dict with three keys:
            ``name``: file name up to its first ``.``;
            ``esm3_structure_seq``: first row of the token tensor as a list;
            ``plddt``: mean B-factor of the file as a Python float.
    """
    # np.unique returns sorted values, so the chain that gets encoded is the
    # one with the smallest chain id, not necessarily the first in the file.
    chain_ids = np.unique(PDBFile.read(pdb_file).get_structure().chain_id)
    chain = ProteinChain.from_pdb(pdb_file, chain_id=chain_ids[0])

    # The pLDDT tensor returned here is not used; the reported value is taken
    # from the B-factor column below instead.
    coords, _, residue_index = chain.to_structure_encoder_inputs()
    coords = coords.to(device)
    residue_index = residue_index.to(device)

    # NOTE(review): the mean is taken over every atom in the file, i.e. over
    # all chains, while the tokens describe only the selected chain.
    structure = bsio.load_structure(pdb_file, extra_fields=["b_factor"])
    # Cast to a builtin float so JSON serialisation never depends on the
    # NumPy scalar dtype.
    plddt = float(structure.b_factor.mean())

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        _, structure_tokens = encoder.encode(coords, residue_index=residue_index)

    return {
        # os.path.basename understands the platform's path separator, unlike
        # a literal split on "/".
        "name": os.path.basename(pdb_file).split(".")[0],
        "esm3_structure_seq": structure_tokens.cpu().numpy().tolist()[0],
        "plddt": plddt,
    }


def _list_pdb_files(pdb_dir: str) -> list[str]:
    """Return sorted paths of the regular ``.pdb`` files directly in *pdb_dir*."""
    candidates = (
        os.path.join(pdb_dir, entry) for entry in sorted(os.listdir(pdb_dir))
    )
    return [
        path
        for path in candidates
        if path.lower().endswith(".pdb") and os.path.isfile(path)
    ]


def main() -> None:
    """Parse the command line, encode the requested structures, write results."""
    parser = argparse.ArgumentParser(
        description="Encode PDB structures into ESM3 structure tokens."
    )
    parser.add_argument("--pdb_file", type=str, default=None)
    parser.add_argument("--pdb_dir", type=str, default=None)
    parser.add_argument("--out_file", type=str, default="esm3_structure_seq.json")
    args = parser.parse_args()

    # --pdb_file takes precedence over --pdb_dir when both are given.
    if args.pdb_file is not None:
        pdb_files = [args.pdb_file]
    elif args.pdb_dir is not None:
        pdb_files = _list_pdb_files(args.pdb_dir)
    else:
        # Fail before loading the model instead of writing an empty file.
        parser.error("one of --pdb_file or --pdb_dir is required")

    # Prefer the first GPU, but keep the script usable on CPU-only hosts.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    encoder = ESM3_structure_encoder_v0(device)

    # A progress bar is only shown for directory runs.
    progress = tqdm(pdb_files, disable=args.pdb_file is not None)
    results = [get_esm3_structure_seq(path, encoder, device) for path in progress]

    # One JSON object per line, without a trailing newline.
    with open(args.out_file, "w", encoding="utf-8") as f:
        f.write("\n".join(json.dumps(result) for result in results))


if __name__ == "__main__":
    main()