# VenusFactory: src/data/get_secondary_structure_seq.py
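"""Extract amino-acid and secondary-structure (DSSP 8-state and 3-state) sequences from PDB files.

For each PDB file the script extracts the amino-acid sequence, runs DSSP to get the
8-state secondary structure, maps it to the 3-state alphabet (H/E/C), and writes one
JSON record per protein ({"name", "aa_seq", "ss8_seq", "ss3_seq"}) to the output file.

Example invocation (paths are illustrative):
    python src/data/get_secondary_structure_seq.py \
        --pdb_dir data/pdbs \
        --num_workers 8 \
        --out_file output/secondary_structure.jsonl
"""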
import os
import sys
sys.path.append(os.getcwd())
import argparse
import json
import pandas as pd
import biotite.structure.io as bsio
from tqdm import tqdm
from Bio import PDB
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.utils.data_utils import extract_seq_from_pdb
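# DSSP 8-state codes mapped to the 3-state alphabet: helix-like -> H, strand-like -> E, coil/other -> C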
ss_alphabet = ['H', 'E', 'C']
ss_alphabet_dic = {
"H": "H", "G": "H", "E": "E",
"B": "E", "I": "C", "T": "C",
"S": "C", "L": "C", "-": "C",
"P": "C"
}
def get_secondary_structure_seq(pdb_file):
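    """Extract amino-acid, 8-state, and 3-state secondary-structure sequences from one PDB file.

    Returns (result_dict, None) on success, where result_dict has keys
    "name", "aa_seq", "ss8_seq", and "ss3_seq"; on failure returns
    (pdb_file, error) so the caller can record which structure failed.
    """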
try:
# extract amino acid sequence
aa_seq = extract_seq_from_pdb(pdb_file)
pdb_parser = PDB.PDBParser(QUIET=True)
structure = pdb_parser.get_structure("protein", pdb_file)
model = structure[0]
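        # Bio.PDB.DSSP wraps the external DSSP executable (dssp/mkdssp), which must be installed and on PATH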
dssp = PDB.DSSP(model, pdb_file)
# extract secondary structure sequence
sec_structures = []
        # dssp_res[2] is the one-letter DSSP secondary-structure code for the residue
        for dssp_res in dssp:
            sec_structures.append(dssp_res[2])
except Exception as e:
return pdb_file, e
sec_structure_str_8 = ''.join(sec_structures)
sec_structure_str_8 = sec_structure_str_8.replace('-', 'L')
if len(aa_seq) != len(sec_structure_str_8):
return pdb_file, f"aa_seq {len(aa_seq)} and sec_structure_str_8 {len(sec_structure_str_8)} length mismatch"
sec_structure_str_3 = ''.join([ss_alphabet_dic[ss] for ss in sec_structures])
    final_dict = {
        "name": os.path.basename(pdb_file).split('.')[0],
        "aa_seq": aa_seq,
        "ss8_seq": sec_structure_str_8,
        "ss3_seq": sec_structure_str_3,
    }
return final_dict, None
if __name__ == '__main__':
parser = argparse.ArgumentParser()
    parser.add_argument('--pdb_dir', type=str, help='directory containing input PDB files')
    parser.add_argument('--pdb_file', type=str, help='a single input PDB file')
    # parallel processing (thread pool)
    parser.add_argument('--num_workers', type=int, default=4, help='number of worker threads')
    # index pdb for large scale inference
    parser.add_argument("--pdb_index_file", default=None, type=str, help="file listing PDB names, one per line")
    parser.add_argument("--pdb_index_level", default=1, type=int, help="number of name-prefix levels used to shard the PDB directory")
    # save file
    parser.add_argument('--error_file', type=str, help='CSV file for PDBs that failed processing')
    parser.add_argument('--out_file', type=str, help='output JSONL file')
args = parser.parse_args()
    out_dir = os.path.dirname(args.out_file)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
if args.pdb_dir is not None:
# load pdb index file
        if args.pdb_index_file:
            with open(args.pdb_index_file) as f:
                pdbs = f.read().splitlines()
            pdb_files = []
            for pdb in pdbs:
                # sharded layout: one sub-directory level per name-prefix character
                pdb_relative_dir = args.pdb_dir
                for i in range(1, args.pdb_index_level + 1):
                    pdb_relative_dir = os.path.join(pdb_relative_dir, pdb[:i])
                pdb_files.append(os.path.join(pdb_relative_dir, pdb + ".pdb"))
# regular pdb dir
else:
pdb_files = sorted([os.path.join(args.pdb_dir, p) for p in os.listdir(args.pdb_dir)])
results, error_pdbs, error_messages = [], [], []
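        # Threads are sufficient here: most of the time is spent waiting on the external DSSP process and file I/O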
with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
futures = [executor.submit(get_secondary_structure_seq, pdb_file) for pdb_file in pdb_files]
with tqdm(total=len(pdb_files), desc="Processing pdb") as progress:
for future in as_completed(futures):
result, message = future.result()
if message is None:
results.append(result)
else:
error_pdbs.append(result)
error_messages.append(message)
progress.update(1)
progress.close()
if error_pdbs:
            if args.error_file is None:
                # derive the error-report path from the output file name
                args.error_file = os.path.splitext(args.out_file)[0] + "_error.csv"
            error_dir = os.path.dirname(args.error_file)
            if error_dir:
                os.makedirs(error_dir, exist_ok=True)
error_info = {"error_pdbs": error_pdbs, "error_messages": error_messages}
pd.DataFrame(error_info).to_csv(args.error_file, index=False)
with open(args.out_file, "w") as f:
f.write("\n".join([json.dumps(r) for r in results]))
    elif args.pdb_file is not None:
        result, message = get_secondary_structure_seq(args.pdb_file)
        if message is not None:
            sys.exit(f"Failed to process {args.pdb_file}: {message}")
        with open(args.out_file, "w") as f:
            json.dump(result, f)
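
# Example of consuming the JSONL output produced in --pdb_dir mode (file name is illustrative):
#   import pandas as pd
#   df = pd.read_json("secondary_structure.jsonl", lines=True)
#   print(df[["name", "ss3_seq"]].head())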