import os
import argparse
import json
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed

import pandas as pd
import torch
from tqdm import tqdm

from src.data.prosst.structure.quantizer import PdbQuantizer
from src.utils.data_utils import extract_seq_from_pdb

warnings.filterwarnings("ignore", category=Warning)

# Size of the discrete structure vocabulary used by the ProSST quantizer
structure_vocab_size = 20
processor = PdbQuantizer(structure_vocab_size=structure_vocab_size)

def get_prosst_token(pdb_file):
    """Generate ProSST structure tokens for a PDB file"""
    try:
        # 提取氨基酸序列
        aa_seq = extract_seq_from_pdb(pdb_file)
        
        # 处理结构序列
        structure_result = processor(pdb_file)
        pdb_name = os.path.basename(pdb_file)
        # 验证数据结构
        if structure_vocab_size not in structure_result:
            raise ValueError(f"Missing structure key: {structure_vocab_size}")
        if pdb_name not in structure_result[structure_vocab_size]:
            raise ValueError(f"Missing PDB entry: {pdb_name}")
        
        struct_sequence = structure_result[structure_vocab_size][pdb_name]['struct']
        struct_sequence = [int(num) for num in struct_sequence]
        
        # Shift structure tokens by 3 to reserve the low ids for special tokens,
        # then wrap the sequence as [1] + tokens + [2]
        structure_sequence_offset = [3 + num for num in struct_sequence]
        structure_input_ids = torch.tensor(
            [[1] + structure_sequence_offset + [2]], 
            dtype=torch.long
        )
        
        return {
            "name": os.path.basename(pdb_file).split('.')[0],
            "aa_seq": aa_seq,
            "struct_tokens": structure_input_ids[0].tolist()
        }, None
        
    except Exception as e:
        return pdb_file, str(e)
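
# Example return value from get_prosst_token (illustrative; actual tokens depend on the PDB):
#   ({"name": "1abc", "aa_seq": "MKT...", "struct_tokens": [1, 45, 12, ..., 2]}, None)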

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='ProSST structure token generator')
    parser.add_argument('--pdb_dir', type=str, help='Directory containing PDB files')
    parser.add_argument('--pdb_file', type=str, help='Single PDB file path')
    parser.add_argument('--num_workers', type=int, default=16, help='Number of parallel workers')
    parser.add_argument('--pdb_index_file', type=str, default=None, help='PDB index file for sharding')
    parser.add_argument('--pdb_index_level', type=int, default=1, help='Directory hierarchy depth')
    parser.add_argument('--error_file', type=str, help='Error log output path')
    parser.add_argument('--out_file', type=str, required=True, help='Output JSON file path')
    args = parser.parse_args()
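
    # Example invocations (script name and file paths are placeholders):
    #   python get_prosst_token.py --pdb_dir ./pdbs --out_file tokens.json
    #   python get_prosst_token.py --pdb_file ./pdbs/1abc.pdb --out_file 1abc_tokens.json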

    if args.pdb_dir is not None:
        # Load PDB names from an index file; files are sharded into subdirectories
        # by name prefix (e.g. with index level 1, "1abc.pdb" lives under <pdb_dir>/1/)
        if args.pdb_index_file:
            pdbs = open(args.pdb_index_file).read().splitlines()
            pdb_files = []
            for pdb in pdbs:
                pdb_relative_dir = args.pdb_dir
                for i in range(1, args.pdb_index_level+1):
                    pdb_relative_dir = os.path.join(pdb_relative_dir, pdb[:i])
                pdb_files.append(os.path.join(pdb_relative_dir, pdb+".pdb"))
        
        # Regular PDB directory: take every .pdb file directly under pdb_dir
        else:
            pdb_files = sorted(
                os.path.join(args.pdb_dir, p)
                for p in os.listdir(args.pdb_dir)
                if p.endswith('.pdb')
            )

        # Process PDB files in parallel; results are collected in completion order, not input order
        results, errors = [], []
        with ThreadPoolExecutor(max_workers=args.num_workers) as executor:
            futures = {executor.submit(get_prosst_token, f): f for f in pdb_files}
            with tqdm(total=len(futures), desc="Processing PDBs") as progress:
                for future in as_completed(futures):
                    result, error = future.result()
                    if error:
                        errors.append({"file": result, "error": error})
                    else:
                        results.append(result)
                    progress.update(1)

        if errors:
            error_path = args.error_file or args.out_file.replace('.json', '_errors.csv')
            pd.DataFrame(errors).to_csv(error_path, index=False)
            print(f"Encountered {len(errors)} errors. Saved to {error_path}")


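        # Note: results are written as JSON Lines (one JSON object per line), not a single JSON array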
        with open(args.out_file, 'w') as f:
            f.write('\n'.join(json.dumps(r) for r in results))


    elif args.pdb_file:
        result, error = get_prosst_token(args.pdb_file)
        if error:
            raise RuntimeError(f"Error processing {args.pdb_file}: {error}")
        with open(args.out_file, 'w') as f:
            json.dump(result, f)

    else:
        parser.error("Either --pdb_dir or --pdb_file must be provided")