VenusFactory / src /data /get_esm3_structure_seq.py
2dogey's picture
Upload folder using huggingface_hub
8918ac7 verified
import torch
import os
import sys
sys.path.append(os.getcwd())
import json
import argparse
import numpy as np
import biotite.structure.io as bsio
from tqdm import tqdm
from biotite.structure.io.pdb import PDBFile
from esm.utils.structure.protein_chain import ProteinChain
from esm.models.vqvae import StructureTokenEncoder
# Size of the VQ-VAE structure codebook; ordinary structure tokens use ids
# below this value (matches n_codes passed to the encoder in this file).
VQVAE_CODEBOOK_SIZE = 4096

# Special structure-token ids, numbered consecutively starting right after the
# codebook range, in the order listed here.
VQVAE_SPECIAL_TOKENS = {
    token_name: VQVAE_CODEBOOK_SIZE + offset
    for offset, token_name in enumerate(
        ("MASK", "EOS", "BOS", "PAD", "CHAINBREAK")
    )
}
def ESM3_structure_encoder_v0(
    device: torch.device | str = "cpu",
    weight_path: str = "./src/data/weight/esm3_structure_encoder_v0.pth",
):
    """Build the ESM3 structure-token encoder and load its weights.

    Args:
        device: Device the model is moved to and onto which the checkpoint
            tensors are mapped.
        weight_path: Path to the state-dict checkpoint. The default is
            relative to the current working directory, so it only resolves
            when the script is launched from the repository root.

    Returns:
        A ``StructureTokenEncoder`` in eval mode with the checkpoint loaded.

    Raises:
        FileNotFoundError: If ``weight_path`` does not point to a file.
    """
    # Fail before allocating the model, with a message that names the path.
    if not os.path.isfile(weight_path):
        raise FileNotFoundError(
            f"ESM3 structure encoder weights not found at '{weight_path}'"
        )

    model = (
        StructureTokenEncoder(
            d_model=1024,
            n_heads=1,
            v_heads=128,
            n_layers=2,
            d_out=128,
            n_codes=4096,
        )
        .to(device)
        .eval()
    )

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source (consider weights_only=True where the installed torch
    # version supports it).
    state_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(state_dict)
    return model
def get_esm3_structure_seq(pdb_file, encoder, device="cuda:0"):
    """Encode one chain of a PDB file into ESM3 structure tokens.

    Only a single chain is encoded: the one whose chain ID sorts first among
    the unique chain IDs found in the file.

    Args:
        pdb_file: Path to the PDB file.
        encoder: Object exposing ``encode(coords, residue_index=...)`` that
            returns a pair whose second element is the structure-token
            tensor, e.g. the model built by ``ESM3_structure_encoder_v0``.
        device: Device the encoder inputs are moved to; it should be the
            device the encoder lives on.

    Returns:
        A dict with:
            ``name``: the file's base name up to its first ``.``;
            ``esm3_structure_seq``: token ids of the first entry of the
                encoder output, as a plain list;
            ``plddt``: mean B-factor over all atoms loaded from the file
                (not only the encoded chain), as a built-in ``float``.
    """
    # np.unique returns the IDs sorted, so index 0 is the smallest chain ID,
    # which is not necessarily the first chain appearing in the file.
    chain_ids = np.unique(PDBFile.read(pdb_file).get_structure().chain_id)
    chain = ProteinChain.from_pdb(pdb_file, chain_id=chain_ids[0])

    # The pLDDT returned with the encoder inputs is not used; the reported
    # value comes from the B-factor column below.
    coords, _, residue_index = chain.to_structure_encoder_inputs()
    coords = coords.to(device)
    residue_index = residue_index.to(device)

    # presumably the B-factor column stores pLDDT (as in predicted
    # structures) -- TODO confirm for experimentally determined inputs.
    struct = bsio.load_structure(pdb_file, extra_fields=["b_factor"])
    # Cast the NumPy scalar so the result is always JSON-serialisable.
    plddt = float(struct.b_factor.mean())

    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        _, structure_tokens = encoder.encode(coords, residue_index=residue_index)

    return {
        # basename handles the platform's path separator; the name is still
        # cut at the first '.' as before.
        'name': os.path.basename(pdb_file).split('.')[0],
        # assumes a leading batch dimension of size 1 -- TODO confirm
        'esm3_structure_seq': structure_tokens.cpu().numpy().tolist()[0],
        'plddt': plddt,
    }
if __name__ == "__main__":
    # Command-line entry point: encode a single PDB file or every file in a
    # directory and write one JSON object per line to --out_file.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pdb_file", type=str, default=None,
                        help="path to a single PDB file to encode")
    parser.add_argument("--pdb_dir", type=str, default=None,
                        help="directory whose files are all encoded "
                             "(ignored when --pdb_file is given)")
    parser.add_argument("--out_file", type=str, default='esm3_structure_seq.json',
                        help="output file, one JSON object per line")
    args = parser.parse_args()

    # Reject an empty invocation before paying for the model load.
    if args.pdb_file is None and args.pdb_dir is None:
        parser.error("one of --pdb_file or --pdb_dir is required")

    # Prefer the first GPU but keep the script usable on CPU-only machines.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    encoder = ESM3_structure_encoder_v0(device)

    results = []
    if args.pdb_file is not None:
        results.append(get_esm3_structure_seq(args.pdb_file, encoder, device))
    else:
        # Sorted for a reproducible output order; sub-directories are skipped.
        pdb_paths = [
            os.path.join(args.pdb_dir, entry)
            for entry in sorted(os.listdir(args.pdb_dir))
        ]
        pdb_paths = [path for path in pdb_paths if os.path.isfile(path)]
        for pdb_path in tqdm(pdb_paths):
            results.append(get_esm3_structure_seq(pdb_path, encoder, device))

    # JSON Lines layout: one record per line, no trailing newline.
    with open(args.out_file, "w") as f:
        f.write("\n".join([json.dumps(r) for r in results]))