# Source: Hugging Face Space (status at capture time: "Runtime error"),
# file size 2,951 bytes, commit 8918ac7.
import torch
import os
import sys
sys.path.append(os.getcwd())
import json
import argparse
import numpy as np
import biotite.structure.io as bsio
from tqdm import tqdm
from biotite.structure.io.pdb import PDBFile
from esm.utils.structure.protein_chain import ProteinChain
from esm.models.vqvae import StructureTokenEncoder
VQVAE_CODEBOOK_SIZE = 4096

# Special-token ids sit directly after the codebook range: the i-th name below
# maps to VQVAE_CODEBOOK_SIZE + i (MASK=4096, EOS=4097, BOS=4098, PAD=4099,
# CHAINBREAK=4100). The tuple order fixes both the ids and the dict's key order.
_VQVAE_SPECIAL_TOKEN_NAMES = ("MASK", "EOS", "BOS", "PAD", "CHAINBREAK")
VQVAE_SPECIAL_TOKENS = {
    token_name: VQVAE_CODEBOOK_SIZE + offset
    for offset, token_name in enumerate(_VQVAE_SPECIAL_TOKEN_NAMES)
}
def ESM3_structure_encoder_v0(
    device: torch.device | str = "cpu",
    weight_path: str = "./src/data/weight/esm3_structure_encoder_v0.pth",
):
    """Build the ESM3 VQ-VAE structure encoder and load its pretrained weights.

    Args:
        device: Device the model is moved to and onto which the checkpoint
            tensors are mapped.
        weight_path: Path to the saved ``state_dict``. The default is the
            repository-relative location used previously; being relative, it
            is resolved against the current working directory.

    Returns:
        A ``StructureTokenEncoder`` on ``device`` in eval mode with the
        checkpoint weights loaded.
    """
    model = (
        StructureTokenEncoder(
            d_model=1024,
            n_heads=1,
            v_heads=128,
            n_layers=2,
            d_out=128,
            # Single source of truth for the codebook size, so it cannot drift
            # from the special-token ids derived from the same constant.
            n_codes=VQVAE_CODEBOOK_SIZE,
        )
        .to(device)
        .eval()
    )
    state_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(state_dict)
    return model
def get_esm3_structure_seq(pdb_file, encoder, device="cuda:0", chain_id=None):
    """Encode one chain of a PDB file into ESM3 structure tokens.

    Args:
        pdb_file: Path to a PDB file.
        encoder: Model exposing ``encode(coords, residue_index=...)`` that
            returns ``(_, structure_tokens)``, e.g. the model built by
            ``ESM3_structure_encoder_v0``.
        device: Device the encoder inputs are moved to; it should match the
            device the encoder lives on.
        chain_id: Chain to encode. When ``None`` (the default), the first of
            the *sorted* unique chain IDs is used, i.e. the alphabetically
            smallest ID, which is not necessarily the first chain in the file.

    Returns:
        dict with keys:
            ``'name'``: file name up to its first ``.``;
            ``'esm3_structure_seq'``: list of structure-token ids for the chain;
            ``'plddt'``: built-in ``float``, the mean B-factor over all atoms
            of the whole structure (every chain, not only the encoded one).

    Raises:
        ValueError: If the structure contains no chains.
    """
    # A single biotite read provides both the chain IDs and the B-factor column.
    atoms = bsio.load_structure(pdb_file, extra_fields=["b_factor"])

    if chain_id is None:
        # np.unique returns IDs sorted, so e.g. ['L', 'H'] in file order
        # becomes ['H', 'L'] and 'H' is selected.
        chain_ids = np.unique(atoms.chain_id)
        if chain_ids.size == 0:
            raise ValueError(f"No chains found in {pdb_file!r}")
        chain_id = chain_ids[0]
    chain = ProteinChain.from_pdb(pdb_file, chain_id=chain_id)

    # The encoder-input pLDDT is unused; the reported value comes from B-factors.
    coords, _, residue_index = chain.to_structure_encoder_inputs()
    # Pure inference: no autograd graph is needed for the tokens.
    with torch.no_grad():
        _, structure_tokens = encoder.encode(
            coords.to(device), residue_index=residue_index.to(device)
        )

    # float(): .mean() yields a NumPy scalar; json.dumps rejects NumPy scalar
    # types that are not float64 (e.g. numpy.float32), so store a built-in float.
    plddt = float(atoms.b_factor.mean())

    return {
        "name": os.path.basename(pdb_file).split(".")[0],
        # Drop the leading batch dimension of the encoder output.
        "esm3_structure_seq": structure_tokens.cpu().numpy().tolist()[0],
        "plddt": plddt,
    }
if __name__ == "__main__":
    # CLI: encode one PDB file (--pdb_file) or every .pdb file in a directory
    # (--pdb_dir) and write one JSON object per line to --out_file.
    parser = argparse.ArgumentParser(
        description="Encode PDB structures into ESM3 structure tokens."
    )
    parser.add_argument("--pdb_file", type=str, default=None)
    parser.add_argument("--pdb_dir", type=str, default=None)
    parser.add_argument("--out_file", type=str, default='esm3_structure_seq.json')
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Torch device; defaults to cuda:0 when CUDA is available, else cpu.",
    )
    args = parser.parse_args()
    if args.pdb_file is None and args.pdb_dir is None:
        # Previously this silently wrote an empty output file.
        parser.error("one of --pdb_file or --pdb_dir is required")

    # Hard-coding cuda:0 crashes on CPU-only hosts; fall back to CPU instead.
    device = args.device or ("cuda:0" if torch.cuda.is_available() else "cpu")
    encoder = ESM3_structure_encoder_v0(device)

    results = []
    if args.pdb_file is not None:
        # --pdb_file takes precedence over --pdb_dir when both are given.
        results.append(get_esm3_structure_seq(args.pdb_file, encoder, device))
    else:
        # Sorted for reproducible output order; skip entries that are not
        # .pdb files (e.g. hidden files), which the PDB parser cannot read.
        pdb_files = sorted(
            fname for fname in os.listdir(args.pdb_dir)
            if fname.lower().endswith(".pdb")
        )
        for pdb_file in tqdm(pdb_files):
            results.append(
                get_esm3_structure_seq(
                    os.path.join(args.pdb_dir, pdb_file), encoder, device
                )
            )

    # JSON Lines: one record per line, no trailing newline (unchanged format).
    with open(args.out_file, "w", encoding="utf-8") as f:
        f.write("\n".join(json.dumps(r) for r in results))
|