File size: 2,951 Bytes
8918ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import torch
import os
import sys
sys.path.append(os.getcwd())
import json
import argparse
import numpy as np
import biotite.structure.io as bsio
from tqdm import tqdm
from biotite.structure.io.pdb import PDBFile
from esm.utils.structure.protein_chain import ProteinChain
from esm.models.vqvae import StructureTokenEncoder

VQVAE_CODEBOOK_SIZE = 4096

# Special-token ids are appended directly after the learned codebook entries,
# in this order: MASK=4096, EOS=4097, BOS=4098, PAD=4099, CHAINBREAK=4100.
VQVAE_SPECIAL_TOKENS = {
    token_name: VQVAE_CODEBOOK_SIZE + offset
    for offset, token_name in enumerate(("MASK", "EOS", "BOS", "PAD", "CHAINBREAK"))
}

def ESM3_structure_encoder_v0(
    device: torch.device | str = "cpu",
    weights_path: str = "./src/data/weight/esm3_structure_encoder_v0.pth",
):
    """Build the ESM3 structure-token encoder and load its pretrained weights.

    Args:
        device: Device the model is moved to and the checkpoint is mapped onto.
        weights_path: Path to the encoder ``state_dict`` checkpoint. The default
            is relative to the current working directory, so either launch from
            the directory that contains ``src/`` or pass an explicit path.

    Returns:
        A ``StructureTokenEncoder`` on ``device``, switched to eval mode.
    """
    model = (
        StructureTokenEncoder(
            d_model=1024,
            n_heads=1,
            v_heads=128,
            n_layers=2,
            d_out=128,
            # Keep in sync with the codebook size used for the special-token ids.
            n_codes=VQVAE_CODEBOOK_SIZE,
        )
        .to(device)
        .eval()
    )
    # The checkpoint is a plain state_dict. weights_only=True restricts
    # unpickling to tensors and plain containers, so a tampered file cannot
    # run arbitrary code when it is loaded.
    state_dict = torch.load(weights_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    return model

def get_esm3_structure_seq(pdb_file, encoder, device="cuda:0"):
    """Encode the first chain of a PDB file into ESM3 structure tokens.

    Args:
        pdb_file: Path to a PDB file.
        encoder: Encoder exposing ``encode(coords, residue_index=...)``, e.g. the
            model returned by ``ESM3_structure_encoder_v0``.
        device: Device the encoder inputs are moved to; should match the
            device the encoder lives on.

    Returns:
        dict with keys:
            ``name``: the file's basename up to its first ``.``;
            ``esm3_structure_seq``: list of structure-token ids for the chain;
            ``plddt``: mean B-factor column value over the encoded chain's
                atoms, as a plain Python float.

    Raises:
        ValueError: if the file contains no atoms.
    """
    # Parse once; chain IDs and B-factors are both taken from this structure.
    struct = bsio.load_structure(pdb_file, extra_fields=["b_factor"])
    if len(struct.chain_id) == 0:
        raise ValueError(f"No atoms found in {pdb_file}")

    # The first atom's chain is the first chain in file order. np.unique would
    # sort the IDs instead (e.g. picking 'H' over 'L' in an L-then-H file).
    chain_id = struct.chain_id[0]
    chain = ProteinChain.from_pdb(pdb_file, chain_id=chain_id)

    # The per-residue pLDDT returned here is not used; the reported score is
    # computed from the B-factor column below.
    coords, _, residue_index = chain.to_structure_encoder_inputs()
    coords = coords.to(device)
    residue_index = residue_index.to(device)
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        _, structure_tokens = encoder.encode(coords, residue_index=residue_index)

    # Average over the encoded chain only, so the score describes the same
    # residues as the tokens. float() turns the NumPy scalar into a type that
    # json.dumps can serialize.
    plddt = float(struct.b_factor[struct.chain_id == chain_id].mean())

    return {
        "name": os.path.basename(pdb_file).split(".")[0],
        "esm3_structure_seq": structure_tokens.cpu().numpy().tolist()[0],
        "plddt": plddt,
    }

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Encode PDB structures into ESM3 structure tokens (JSON Lines output)."
    )
    parser.add_argument("--pdb_file", type=str, default=None,
                        help="single PDB file to encode (takes precedence over --pdb_dir)")
    parser.add_argument("--pdb_dir", type=str, default=None,
                        help="directory whose files are all encoded")
    parser.add_argument("--out_file", type=str, default='esm3_structure_seq.json',
                        help="output path; one JSON object per line")
    args = parser.parse_args()
    # Fail fast rather than loading the model and writing an empty file.
    if args.pdb_file is None and args.pdb_dir is None:
        parser.error("one of --pdb_file or --pdb_dir is required")

    # Use the first GPU when present, otherwise fall back to CPU so the script
    # still runs on machines without CUDA.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    encoder = ESM3_structure_encoder_v0(device)

    results = []
    if args.pdb_file is not None:
        results.append(get_esm3_structure_seq(args.pdb_file, encoder, device))
    else:
        # Sorted for a deterministic output order; subdirectories are skipped
        # because they cannot be parsed as structures.
        pdb_paths = [
            os.path.join(args.pdb_dir, entry)
            for entry in sorted(os.listdir(args.pdb_dir))
            if os.path.isfile(os.path.join(args.pdb_dir, entry))
        ]
        for pdb_path in tqdm(pdb_paths):
            results.append(get_esm3_structure_seq(pdb_path, encoder, device))

    # JSON Lines: one newline-terminated record per structure.
    with open(args.out_file, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(r) + "\n" for r in results)