import os
import sys
sys.path.append(os.getcwd())
import argparse
import json
import pandas as pd
from tqdm import tqdm
from Bio import PDB
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.utils.data_utils import extract_seq_from_pdb
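
# Example usage (a sketch; the script filename below is assumed, adjust to where this file lives):
#   python get_secondary_structure.py --pdb_dir ./pdbs --out_file ./out/ss.jsonl --num_workers 8
#   python get_secondary_structure.py --pdb_file ./pdbs/1abc.pdb --out_file ./out/1abc.json
# Note: Bio.PDB.DSSP shells out to the DSSP executable (mkdssp), which must be installed and on PATH.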


ss_alphabet = ['H', 'E', 'C']
ss_alphabet_dic = {
    "H": "H", "G": "H", "E": "E",
    "B": "E", "I": "C", "T": "C",
    "S": "C", "L": "C", "-": "C",
    "P": "C"
}
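
# ss_alphabet_dic collapses DSSP 8-state codes into the 3-state classes in ss_alphabet:
# helix types (H, G) -> H, strand types (E, B) -> E, and the remaining codes
# (I, T, S, L, '-', P) -> C (coil). '-' is DSSP's code for residues with no assignment.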

def get_secondary_structure_seq(pdb_file):
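    """Run DSSP on a single PDB file and collect per-residue secondary structure.

    On success returns (result_dict, None), where result_dict holds "name",
    "aa_seq", "ss8_seq" (8-state DSSP codes with '-' rewritten to 'L') and
    "ss3_seq" (codes collapsed to H/E/C). On failure returns (pdb_file, error).
    """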
    try:
        # extract amino acid sequence
        aa_seq = extract_seq_from_pdb(pdb_file)
        pdb_parser = PDB.PDBParser(QUIET=True)
        structure = pdb_parser.get_structure("protein", pdb_file)
        model = structure[0]
        dssp = PDB.DSSP(model, pdb_file)
        # extract secondary structure sequence
        sec_structures = []
        for dssp_res in dssp:
            # index 2 of each DSSP record is the one-letter secondary structure code
            sec_structures.append(dssp_res[2])
    
    except Exception as e:
        return pdb_file, e
    
    sec_structure_str_8 = ''.join(sec_structures)
    sec_structure_str_8 = sec_structure_str_8.replace('-', 'L')
    if len(aa_seq) != len(sec_structure_str_8):
        return pdb_file, f"aa_seq {len(aa_seq)} and sec_structure_str_8 {len(sec_structure_str_8)} length mismatch"
    
    sec_structure_str_3 = ''.join([ss_alphabet_dic[ss] for ss in sec_structures])
    
    final_dict = {}
    final_dict["name"] = os.path.splitext(os.path.basename(pdb_file))[0]
    final_dict["aa_seq"] = aa_seq
    final_dict["ss8_seq"] = sec_structure_str_8
    final_dict["ss3_seq"] = sec_structure_str_3
    
    return final_dict, None

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pdb_dir', type=str, help='directory of PDB files to process')
    parser.add_argument('--pdb_file', type=str, help='path to a single PDB file to process')
    
    # multi processing
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers')
    
    # index pdb for large scale inference
    parser.add_argument("--pdb_index_file", default=None, type=str, help="pdb index file")
    parser.add_argument("--pdb_index_level", default=1, type=int, help="pdb index level")
    
    # save file
    parser.add_argument('--error_file', type=str, help='CSV file to record PDBs that failed processing')
    parser.add_argument('--out_file', type=str, help='output file (JSON lines for --pdb_dir, a single JSON object for --pdb_file)')
    args = parser.parse_args()
    
    out_dir = os.path.dirname(args.out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    
    if args.pdb_dir is not None:
        # load pdb index file
        if args.pdb_index_file:            
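            # The index file lists one PDB name per line; files are expected to sit in a
            # name-prefix sharded layout, e.g. with --pdb_index_level 2 the entry "1abc"
            # resolves to <pdb_dir>/1/1a/1abc.pdb.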
            with open(args.pdb_index_file) as f:
                pdbs = f.read().splitlines()
            pdb_files = []
            for pdb in pdbs:
                pdb_relative_dir = args.pdb_dir
                for i in range(1, args.pdb_index_level+1):
                    pdb_relative_dir = os.path.join(pdb_relative_dir, pdb[:i])
                pdb_files.append(os.path.join(pdb_relative_dir, pdb+".pdb"))
        
        # regular pdb dir
        else:
            pdb_files = sorted(os.path.join(args.pdb_dir, p) for p in os.listdir(args.pdb_dir) if p.endswith(".pdb"))
            
        results, error_pdbs, error_messages = [], [], []
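        # Bio.PDB.DSSP runs the external DSSP binary for each file, so a thread pool is
        # enough to overlap the work despite the GIL.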
        with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
            futures = [executor.submit(get_secondary_structure_seq, pdb_file) for pdb_file in pdb_files]

            with tqdm(total=len(pdb_files), desc="Processing pdb") as progress:
                for future in as_completed(futures):
                    result, message = future.result()
                    if message is None:
                        results.append(result)
                    else:
                        error_pdbs.append(result)
                        error_messages.append(message)
                    progress.update(1)
        
        if error_pdbs:
            if args.error_file is None:
                args.error_file = os.path.splitext(args.out_file)[0] + "_error.csv"
            error_dir = os.path.dirname(args.error_file)
            if error_dir:
                os.makedirs(error_dir, exist_ok=True)
            error_info = {"error_pdbs": error_pdbs, "error_messages": error_messages}
            pd.DataFrame(error_info).to_csv(args.error_file, index=False)
        
        with open(args.out_file, "w") as f:
            f.write("\n".join(json.dumps(r) for r in results) + "\n")
    
    elif args.pdb_file is not None:
        result, message = get_secondary_structure_seq(args.pdb_file)
        if message is not None:
            sys.exit(f"Failed to process {args.pdb_file}: {message}")
        with open(args.out_file, "w") as f:
            json.dump(result, f)

    else:
        parser.error("one of --pdb_dir or --pdb_file must be provided")