# NOTE: page-status banner captured together with this file ("Spaces: Runtime error"); it is not part of the script.
import argparse
import json
import os
import sys

# Must run before the imports below so that modules located under the current
# working directory can be resolved.
sys.path.append(os.getcwd())

import biotite.structure.io as bsio
import numpy as np
import torch
from biotite.structure.io.pdb import PDBFile
from esm.models.vqvae import StructureTokenEncoder
from esm.utils.structure.protein_chain import ProteinChain
from tqdm import tqdm
# Size of the structure VQ-VAE codebook; the same value is passed as `n_codes`
# when the StructureTokenEncoder is built in this file.
VQVAE_CODEBOOK_SIZE = 4096

# Ids of the special structure tokens. They are numbered consecutively starting
# at VQVAE_CODEBOOK_SIZE, i.e. directly after the last codebook index.
VQVAE_SPECIAL_TOKENS = {
    "MASK": VQVAE_CODEBOOK_SIZE,
    "EOS": VQVAE_CODEBOOK_SIZE + 1,
    "BOS": VQVAE_CODEBOOK_SIZE + 2,
    "PAD": VQVAE_CODEBOOK_SIZE + 3,
    "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4,
}
def ESM3_structure_encoder_v0(
    device: torch.device | str = "cpu",
    weight_path: str = "./src/data/weight/esm3_structure_encoder_v0.pth",
):
    """Build the structure-token encoder and load its weights from disk.

    Args:
        device: Device the model is moved to; the checkpoint tensors are
            mapped onto the same device while loading.
        weight_path: Path of the checkpoint file containing the encoder's
            state dict. The default is relative to the current working
            directory, so pass an explicit path when the script is started
            from somewhere else.

    Returns:
        The ``StructureTokenEncoder`` instance, in eval mode, with the
        checkpoint's state dict loaded.
    """
    model = (
        StructureTokenEncoder(
            d_model=1024, n_heads=1, v_heads=128, n_layers=2, d_out=128, n_codes=4096
        )
        .to(device)
        .eval()
    )
    # NOTE(review): torch.load unpickles the file, so only point `weight_path`
    # at checkpoints from a trusted source. Consider `weights_only=True` once
    # the minimum supported torch version is confirmed to provide it.
    state_dict = torch.load(weight_path, map_location=device)
    model.load_state_dict(state_dict)
    return model
def get_esm3_structure_seq(pdb_file, encoder, device="cuda:0"):
    """Encode one chain of a PDB file into a structure-token sequence.

    Args:
        pdb_file: Path to the PDB file to encode.
        encoder: Object exposing ``encode(coords, residue_index=...)`` whose
            second return value holds the structure tokens, e.g. the model
            returned by ``ESM3_structure_encoder_v0``.
        device: Device the encoder inputs are moved to. It should be the
            device the encoder itself lives on.

    Returns:
        A dict with the keys ``'name'`` (file name up to its first dot),
        ``'esm3_structure_seq'`` (list of token ids, first batch entry) and
        ``'plddt'`` (mean B-factor of the file as a Python float).
    """
    # np.unique returns the chain ids sorted, so index 0 is the smallest id
    # (e.g. 'H' for a file with chains L and H), not necessarily the chain
    # that appears first in the file.
    chain_ids = np.unique(PDBFile.read(pdb_file).get_structure().chain_id)
    chain = ProteinChain.from_pdb(pdb_file, chain_id=chain_ids[0])

    # The second value returned here is not used: the reported pLDDT is taken
    # from the B-factor column below instead.
    coords, _, residue_index = chain.to_structure_encoder_inputs()
    coords = coords.to(device)
    residue_index = residue_index.to(device)

    # NOTE(review): the mean runs over every atom in the file, i.e. over all
    # chains, while the tokens cover only the selected chain. Presumably the
    # B-factor column stores pLDDT for predicted structures; confirm.
    struct = bsio.load_structure(pdb_file, extra_fields=["b_factor"])
    # float() turns the NumPy scalar into a plain Python float so the result
    # is JSON-serializable regardless of the array's dtype.
    plddt = float(struct.b_factor.mean())

    # Inference only: no autograd graph is needed for tokenization.
    with torch.no_grad():
        _, structure_tokens = encoder.encode(coords, residue_index=residue_index)

    return {
        # basename() handles the platform's path separators; the name is still
        # cut at the first dot, as before.
        'name': os.path.basename(pdb_file).split('.')[0],
        'esm3_structure_seq': structure_tokens.cpu().numpy().tolist()[0],
        'plddt': plddt,
    }
if __name__ == "__main__":
    # Command-line entry point: encode a single PDB file (--pdb_file) or every
    # PDB file in a directory (--pdb_dir) and write one JSON object per line
    # to --out_file. --pdb_file takes precedence when both are given.
    parser = argparse.ArgumentParser()
    parser.add_argument("--pdb_file", type=str, default=None)
    parser.add_argument("--pdb_dir", type=str, default=None)
    parser.add_argument("--out_file", type=str, default='esm3_structure_seq.json')
    args = parser.parse_args()

    # Fail before loading the model when there is nothing to process, instead
    # of silently writing an empty output file.
    if args.pdb_file is None and args.pdb_dir is None:
        parser.error("one of --pdb_file or --pdb_dir is required")

    # Fall back to the CPU on hosts without a usable CUDA device.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    encoder = ESM3_structure_encoder_v0(device)

    results = []
    if args.pdb_file is not None:
        results.append(get_esm3_structure_seq(args.pdb_file, encoder, device))
    else:
        # Keep only regular files with a .pdb suffix so sub-directories and
        # unrelated files are not handed to the PDB parser; sort the entries
        # so the output order is reproducible.
        pdb_paths = sorted(
            path
            for path in (
                os.path.join(args.pdb_dir, entry)
                for entry in os.listdir(args.pdb_dir)
            )
            if path.lower().endswith(".pdb") and os.path.isfile(path)
        )
        for pdb_path in tqdm(pdb_paths):
            results.append(get_esm3_structure_seq(pdb_path, encoder, device))

    # One JSON document per line, without a trailing newline.
    with open(args.out_file, "w") as f:
        f.write("\n".join(json.dumps(r) for r in results))