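"""Generate ProSST structure tokens from PDB files.

Usage sketch (the script name and paths below are illustrative, not taken from
the repository):

    python get_prosst_tokens.py --pdb_dir /path/to/pdbs --out_file tokens.json
    python get_prosst_tokens.py --pdb_file /path/to/1abc.pdb --out_file 1abc.json
"""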
import os
import sys
import argparse
import json
import pandas as pd
import torch
from tqdm import tqdm
from Bio import SeqIO
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.data.prosst.structure.quantizer import PdbQuantizer
from src.utils.data_utils import extract_seq_from_pdb
import warnings
warnings.filterwarnings("ignore", category=Warning)
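
# Note (added): PdbQuantizer maps each PDB structure to a sequence of integer
# tokens drawn from a vocabulary of `structure_vocab_size` (20 here); the exact
# codebook behaviour is defined by the ProSST quantizer implementation.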
structure_vocab_size = 20
processor = PdbQuantizer(structure_vocab_size=structure_vocab_size)
def get_prosst_token(pdb_file):
    """Generate ProSST structure tokens for a PDB file"""
    try:
        # Extract the amino-acid sequence
        aa_seq = extract_seq_from_pdb(pdb_file)

        # Quantize the structure into discrete tokens
        structure_result = processor(pdb_file)
        pdb_name = os.path.basename(pdb_file)

        # Validate the quantizer output
        if structure_vocab_size not in structure_result:
            raise ValueError(f"Missing structure key: {structure_vocab_size}")
        if pdb_name not in structure_result[structure_vocab_size]:
            raise ValueError(f"Missing PDB entry: {pdb_name}")

        struct_sequence = structure_result[structure_vocab_size][pdb_name]['struct']
        struct_sequence = [int(num) for num in struct_sequence]

        # Offset raw tokens by 3, then add the special tokens: [1] + sequence + [2]
        structure_sequence_offset = [3 + num for num in struct_sequence]
        structure_input_ids = torch.tensor(
            [[1] + structure_sequence_offset + [2]],
            dtype=torch.long
        )

        return {
            "name": os.path.basename(pdb_file).split('.')[0],
            "aa_seq": aa_seq,
            "struct_tokens": structure_input_ids[0].tolist()
        }, None
    except Exception as e:
        return pdb_file, str(e)
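
# Illustrative shape of a successful per-file record (values are made up):
#   {"name": "1abc", "aa_seq": "MKT...", "struct_tokens": [1, 10, 7, ..., 2]}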
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ProSST structure token generator')
    parser.add_argument('--pdb_dir', type=str, help='Directory containing PDB files')
    parser.add_argument('--pdb_file', type=str, help='Single PDB file path')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of parallel workers')
    parser.add_argument('--pdb_index_file', type=str, default=None, help='PDB index file for sharding')
    parser.add_argument('--pdb_index_level', type=int, default=1, help='Directory hierarchy depth')
    parser.add_argument('--error_file', type=str, help='Error log output path')
    parser.add_argument('--out_file', type=str, required=True, help='Output JSON file path')
    args = parser.parse_args()
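
    # Example invocation for a sharded PDB collection (paths are illustrative):
    #   python <this_script>.py --pdb_dir data/pdb --pdb_index_file data/index.txt \
    #       --pdb_index_level 2 --num_workers 16 --out_file data/prosst_tokens.json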
    if args.pdb_dir is not None:
        # Load the PDB index file (sharded directory layout)
        if args.pdb_index_file:
            with open(args.pdb_index_file) as idx:
                pdbs = idx.read().splitlines()
            pdb_files = []
            for pdb in pdbs:
                pdb_relative_dir = args.pdb_dir
                for i in range(1, args.pdb_index_level + 1):
                    pdb_relative_dir = os.path.join(pdb_relative_dir, pdb[:i])
                pdb_files.append(os.path.join(pdb_relative_dir, pdb + ".pdb"))
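                # e.g. with --pdb_index_level 2, an index entry "1abc" resolves to
                # <pdb_dir>/1/1a/1abc.pdb (layout inferred from the loop above)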
        # Regular flat PDB directory
        else:
            pdb_files = sorted([os.path.join(args.pdb_dir, p) for p in os.listdir(args.pdb_dir)])

        # Process files in parallel
        results, errors = [], []
        with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
            futures = {executor.submit(get_prosst_token, f): f for f in pdb_files}
            with tqdm(total=len(futures), desc="Processing PDBs") as progress:
                for future in as_completed(futures):
                    result, error = future.result()
                    if error:
                        errors.append({"file": result, "error": error})
                    else:
                        results.append(result)
                    progress.update(1)

        if errors:
            error_path = args.error_file or args.out_file.replace('.json', '_errors.csv')
            pd.DataFrame(errors).to_csv(error_path, index=False)
            print(f"Encountered {len(errors)} errors. Saved to {error_path}")

        # Write one JSON record per line (JSON Lines)
        with open(args.out_file, 'w') as f:
            f.write('\n'.join(json.dumps(r) for r in results))
    elif args.pdb_file:
        result, error = get_prosst_token(args.pdb_file)
        if error:
            raise RuntimeError(f"Error processing {args.pdb_file}: {error}")
        with open(args.out_file, 'w') as f:
            json.dump(result, f)