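"""Generate ProSST structure tokens for PDB files.

Reads a single PDB file or a directory of PDB files, quantizes each structure
with ProSST's PdbQuantizer, and writes the resulting token records to
JSON (single file) or JSONL (directory mode).
"""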
import os
import sys
import argparse
import json
import pandas as pd
import torch
from tqdm import tqdm
from Bio import SeqIO
from concurrent.futures import ThreadPoolExecutor, as_completed
from src.data.prosst.structure.quantizer import PdbQuantizer
from src.utils.data_utils import extract_seq_from_pdb
import warnings

warnings.filterwarnings("ignore", category=Warning)
structure_vocab_size = 20
processor = PdbQuantizer(structure_vocab_size=structure_vocab_size)
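# structure_vocab_size selects which ProSST structure codebook the quantizer
# uses; 20 is the coarsest size. The ProSST release also provides larger
# codebooks (e.g. 2048) -- an assumption based on the upstream project, so
# adjust this value if your quantizer was built with a different vocabulary.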
def get_prosst_token(pdb_file):
    """Generate ProSST structure tokens for a PDB file"""
    try:
        # Extract the amino-acid sequence
        aa_seq = extract_seq_from_pdb(pdb_file)
        # Quantize the structure into a token sequence
        structure_result = processor(pdb_file)
        pdb_name = os.path.basename(pdb_file)
        # Validate the quantizer output layout
        if structure_vocab_size not in structure_result:
            raise ValueError(f"Missing structure key: {structure_vocab_size}")
        if pdb_name not in structure_result[structure_vocab_size]:
            raise ValueError(f"Missing PDB entry: {pdb_name}")
        struct_sequence = structure_result[structure_vocab_size][pdb_name]['struct']
        struct_sequence = [int(num) for num in struct_sequence]
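        # NOTE (assumption): the +3 offset below appears to reserve ids 0-2
        # for special tokens -- 1 and 2 are used as the boundary tokens added
        # next, and 0 is presumably padding.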
        # Add special tokens: [1] + sequence + [2]
        structure_sequence_offset = [3 + num for num in struct_sequence]
        structure_input_ids = torch.tensor(
            [[1] + structure_sequence_offset + [2]],
            dtype=torch.long
        )
        return {
            "name": os.path.basename(pdb_file).split('.')[0],
            "aa_seq": aa_seq,
            "struct_tokens": structure_input_ids[0].tolist()
        }, None
    except Exception as e:
        return pdb_file, str(e)
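# A successful call returns (record, None), where record is shaped like
#   {"name": "1abc", "aa_seq": "MKT...", "struct_tokens": [1, 7, ..., 2]}
# (values illustrative); on failure it returns (pdb_file, error_message).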
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ProSST structure token generator')
    parser.add_argument('--pdb_dir', type=str, help='Directory containing PDB files')
    parser.add_argument('--pdb_file', type=str, help='Single PDB file path')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of parallel workers')
    parser.add_argument('--pdb_index_file', type=str, default=None, help='PDB index file for sharding')
    parser.add_argument('--pdb_index_level', type=int, default=1, help='Directory hierarchy depth')
    parser.add_argument('--error_file', type=str, help='Error log output path')
    parser.add_argument('--out_file', type=str, required=True,
                        help='Output path (JSONL for --pdb_dir, JSON for --pdb_file)')
    args = parser.parse_args()
    if args.pdb_dir is not None:
        # Load the PDB index file, if given. Each line is a PDB name that is
        # resolved through a sharded directory layout: with
        # --pdb_index_level 2, an entry "1abc" maps to <pdb_dir>/1/1a/1abc.pdb
        if args.pdb_index_file:
            pdbs = open(args.pdb_index_file).read().splitlines()
            pdb_files = []
            for pdb in pdbs:
                pdb_relative_dir = args.pdb_dir
                for i in range(1, args.pdb_index_level + 1):
                    pdb_relative_dir = os.path.join(pdb_relative_dir, pdb[:i])
                pdb_files.append(os.path.join(pdb_relative_dir, pdb + ".pdb"))
        # Otherwise treat pdb_dir as a flat directory of PDB files
        else:
            pdb_files = sorted([os.path.join(args.pdb_dir, p) for p in os.listdir(args.pdb_dir)])
        # Process files in parallel
        results, errors = [], []
        with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
            futures = {executor.submit(get_prosst_token, f): f for f in pdb_files}
            with tqdm(total=len(futures), desc="Processing PDBs") as progress:
                for future in as_completed(futures):
                    result, error = future.result()
                    if error:
                        errors.append({"file": result, "error": error})
                    else:
                        results.append(result)
                    progress.update(1)
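        # NOTE: as_completed yields futures in completion order, so records in
        # the output file are not guaranteed to follow the input file order.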
        if errors:
            error_path = args.error_file or args.out_file.replace('.json', '_errors.csv')
            pd.DataFrame(errors).to_csv(error_path, index=False)
            print(f"Encountered {len(errors)} errors. Saved to {error_path}")
        # Write one JSON record per line (JSONL)
        with open(args.out_file, 'w') as f:
            f.write('\n'.join(json.dumps(r) for r in results))
    elif args.pdb_file:
        result, error = get_prosst_token(args.pdb_file)
        if error:
            raise RuntimeError(f"Error processing {args.pdb_file}: {error}")
        with open(args.out_file, 'w') as f:
            json.dump(result, f)
    else:
        parser.error("one of --pdb_dir or --pdb_file is required")
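# Example usage (illustrative paths; assumes this script is saved as
# get_prosst_token.py and run from the repo root so src/ is importable):
#   python get_prosst_token.py --pdb_dir ./pdbs --num_workers 16 --out_file tokens.jsonl
#   python get_prosst_token.py --pdb_file ./pdbs/1abc.pdb --out_file 1abc.json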