import os
import sys

sys.path.append(os.getcwd())

import argparse
import json

import pandas as pd
from tqdm import tqdm
from Bio import PDB
from concurrent.futures import ThreadPoolExecutor, as_completed

from src.utils.data_utils import extract_seq_from_pdb
# DSSP 8-state codes mapped to the 3-state alphabet: helix (H), strand (E), coil (C)
ss_alphabet = ['H', 'E', 'C']
ss_alphabet_dic = {
    "H": "H", "G": "H", "E": "E",
    "B": "E", "I": "C", "T": "C",
    "S": "C", "L": "C", "-": "C",
    "P": "C"
}
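# For example, the 8-state string "HGEBITS-P" collapses to "HHEECCCCC"
# under this mapping.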
def get_secondary_structure_seq(pdb_file):
    """Return ({name, aa_seq, ss8_seq, ss3_seq}, None) on success, or (pdb_file, error) on failure."""
    try:
        # extract amino acid sequence
        aa_seq = extract_seq_from_pdb(pdb_file)

        # run DSSP on the first model (needs an external DSSP binary, e.g. mkdssp, on PATH)
        pdb_parser = PDB.PDBParser(QUIET=True)
        structure = pdb_parser.get_structure("protein", pdb_file)
        model = structure[0]
        dssp = PDB.DSSP(model, pdb_file)

        # extract secondary structure sequence (the DSSP code is field 2 of each residue tuple)
        sec_structures = [dssp_res[2] for dssp_res in dssp]
    except Exception as e:
        return pdb_file, e

    sec_structure_str_8 = ''.join(sec_structures).replace('-', 'L')
    if len(aa_seq) != len(sec_structure_str_8):
        return pdb_file, f"aa_seq {len(aa_seq)} and sec_structure_str_8 {len(sec_structure_str_8)} length mismatch"
    sec_structure_str_3 = ''.join(ss_alphabet_dic[ss] for ss in sec_structures)

    final_dict = {
        "name": os.path.splitext(os.path.basename(pdb_file))[0],
        "aa_seq": aa_seq,
        "ss8_seq": sec_structure_str_8,
        "ss3_seq": sec_structure_str_3,
    }
    return final_dict, None
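# A minimal sketch of the return contract (file name hypothetical):
#   result, err = get_secondary_structure_seq("1abc.pdb")
#   # on success: err is None and result == {"name": "1abc", "aa_seq": "MKT...",
#   #                                        "ss8_seq": "CHHH...", "ss3_seq": "CHHH..."}
#   # on failure: result is the pdb path and err holds the exception or message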
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--pdb_dir', type=str, help='directory of PDB files')
    parser.add_argument('--pdb_file', type=str, help='single PDB file')
    # multi processing
    parser.add_argument('--num_workers', type=int, default=4, help='number of workers')
    # index pdb for large scale inference
    parser.add_argument("--pdb_index_file", default=None, type=str, help="pdb index file")
    parser.add_argument("--pdb_index_level", default=1, type=int, help="pdb index level")
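    # With --pdb_index_level N, an index entry "abcd" is resolved to
    # <pdb_dir>/a/ab/.../abcd.pdb, descending one level per prefix character.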
    # save file
    parser.add_argument('--error_file', type=str, help='csv file for failed PDBs')
    parser.add_argument('--out_file', type=str, help='jsonl output file')
    args = parser.parse_args()

    out_dir = os.path.dirname(args.out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    if args.pdb_dir is not None:
        # load pdb index file
        if args.pdb_index_file:
            pdbs = open(args.pdb_index_file).read().splitlines()
            pdb_files = []
            for pdb in pdbs:
                pdb_relative_dir = args.pdb_dir
                for i in range(1, args.pdb_index_level + 1):
                    pdb_relative_dir = os.path.join(pdb_relative_dir, pdb[:i])
                pdb_files.append(os.path.join(pdb_relative_dir, pdb + ".pdb"))
        # regular pdb dir
        else:
            pdb_files = sorted([os.path.join(args.pdb_dir, p) for p in os.listdir(args.pdb_dir)])
        results, error_pdbs, error_messages = [], [], []
        with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
            futures = [executor.submit(get_secondary_structure_seq, pdb_file) for pdb_file in pdb_files]
            with tqdm(total=len(pdb_files), desc="Processing pdb") as progress:
                for future in as_completed(futures):
                    result, message = future.result()
                    if message is None:
                        results.append(result)
                    else:
                        error_pdbs.append(result)
                        error_messages.append(str(message))
                    progress.update(1)

        if error_pdbs:
            if args.error_file is None:
                args.error_file = os.path.splitext(args.out_file)[0] + "_error.csv"
            error_dir = os.path.dirname(args.error_file)
            if error_dir:
                os.makedirs(error_dir, exist_ok=True)
            error_info = {"error_pdbs": error_pdbs, "error_messages": error_messages}
            pd.DataFrame(error_info).to_csv(args.error_file, index=False)

        # one JSON record per line
        with open(args.out_file, "w") as f:
            f.write("\n".join(json.dumps(r) for r in results))
    elif args.pdb_file is not None:
        result, message = get_secondary_structure_seq(args.pdb_file)
        if message is not None:
            sys.exit(f"Failed to process {args.pdb_file}: {message}")
        with open(args.out_file, "w") as f:
            json.dump(result, f)
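# Example invocations (script name and paths are hypothetical):
#   python get_secondary_structure.py --pdb_dir ./pdbs --out_file ./out/ss.jsonl --num_workers 8
#   python get_secondary_structure.py --pdb_file ./pdbs/1abc.pdb --out_file ./out/1abc.json
# The jsonl output can be loaded back with pandas:
#   pd.read_json("./out/ss.jsonl", lines=True)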