File size: 6,770 Bytes
8918ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import torch
import re
from typing import Dict, List, Any
from transformers import PreTrainedTokenizer
from dataclasses import dataclass

VQVAE_CODEBOOK_SIZE = 4096
VQVAE_SPECIAL_TOKENS = {
    "MASK": VQVAE_CODEBOOK_SIZE,
    "EOS": VQVAE_CODEBOOK_SIZE + 1,
    "BOS": VQVAE_CODEBOOK_SIZE + 2,
    "PAD": VQVAE_CODEBOOK_SIZE + 3,
    "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4,
}

@dataclass
class Collator:
    """Data collator class for protein sequences."""
    tokenizer: PreTrainedTokenizer
    max_length: int = None
    structure_seq: List[str] = None
    problem_type: str = 'classification'
    plm_model: str = None
    num_labels: int = None

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate function for batching examples."""
        # Initialize lists to store sequences and labels
        if "ProSST" in self.plm_model:
            aa_seqs, labels, str_tokens = [], [], []
        else:
            aa_seqs, labels = [], []
        structure_seqs = {
            seq_type: [] for seq_type in (self.structure_seq or [])
        }

        # Process each example
        for e in examples:
            # Process sequences
            aa_seq = self.process_sequence(e["aa_seq"])
            aa_seqs.append(aa_seq)
            if "ProSST" in self.plm_model:
                stru_token = self.process_stru_tokens(e["prosst_stru_token"])
                str_tokens.append(stru_token)
            
            # Process structure sequences if needed
            for seq_type in structure_seqs:
                if seq_type == 'esm3_structure_seq':
                    processed_seq = self.process_esm3_structure_seq(e[seq_type])
                else:
                    processed_seq = self.process_sequence(e[seq_type])
                structure_seqs[seq_type].append(processed_seq)
            
            if self.problem_type == 'multi_label_classification':
                label_list = e['label'].split(',')
                e['label'] = [int(l) for l in label_list]
                binary_list = [0] * self.num_labels
                for index in e['label']:
                    binary_list[index] = 1
                e['label'] = binary_list
            # Process labels
            labels.append(e["label"])

        # Tokenize sequences
        if "ProSST" in self.plm_model:
            batch = self.tokenize_sequences(aa_seqs, structure_seqs, str_tokens)
        else:
            batch = self.tokenize_sequences(aa_seqs, structure_seqs)
        
        # Add labels to batch
        batch["label"] = torch.as_tensor(
            labels, 
            dtype=torch.float if self.problem_type == 'regression' else torch.long
        )

        return batch

    def process_sequence(self, seq: str) -> str:
        """Process sequence based on model type."""
        if 'prot_bert' in self.plm_model or "prot_t5" in self.plm_model:
            seq = " ".join(list(seq))
            seq = re.sub(r"[UZOB]", "X", seq)
        return seq

    def process_esm3_structure_seq(self, seq: List[int]) -> torch.Tensor:
        """Process ESM3 structure sequence."""
        return torch.tensor([VQVAE_SPECIAL_TOKENS["BOS"]] + seq + [VQVAE_SPECIAL_TOKENS["EOS"]])

    def process_stru_tokens(self, seq:List[int]) -> torch.Tensor:
        """Process ProSST structure token."""
        if isinstance(seq, str):
            seq_clean = seq.strip("[]").replace(" ","")
            tokens = list(map(int, seq_clean.split(','))) if seq_clean else []
        elif isinstance(seq, (list, tuple)):
            tokens = [int(x) for x in seq]
        stru_tokens = [int(num) for num in tokens]
        return torch.tensor(stru_tokens)
    
    def tokenize_sequences(

        self,

        aa_seqs: List[str],

        structure_seqs: Dict[str, List[str]],

        str_tokens: List[str] = None,

    ) -> Dict[str, torch.Tensor]:
        """Tokenize all sequences."""
        # Process amino acid sequences
        if "esm1b" in self.plm_model or "esm1v" in self.plm_model:
            self.max_length = 1022
        aa_encodings = self.tokenizer(
            aa_seqs,
            padding=True,
            truncation=True if self.max_length else False,
            max_length=self.max_length,
            return_tensors="pt"
        )
        
        aa_max_length = len(aa_encodings["input_ids"][0])
        padded_tokens = []
        if str_tokens:
            for tokens in str_tokens:
                struct_sequence =  [int(num) for num in tokens]
                padded_tokens.append(struct_sequence + [0] * (aa_max_length - len(struct_sequence) - 2))
            batch = {
                "aa_seq_input_ids": aa_encodings["input_ids"],
                "aa_seq_attention_mask": aa_encodings["attention_mask"],
                "aa_seq_stru_tokens": torch.tensor(padded_tokens, dtype=torch.long)
            }
        else:
            batch = {
            "aa_seq_input_ids": aa_encodings["input_ids"],
            "aa_seq_attention_mask": aa_encodings["attention_mask"]
        }
        
        # Process structure sequences if provided
        for seq_type, seqs in structure_seqs.items():
            if not seqs:
                continue
                
            if seq_type == 'esm3_structure_seq':
                # ESM3 structure sequences are already tokenized
                structure_tokens = torch.stack(seqs)
                # Pad sequences to max length
                max_len = max(len(seq) for seq in seqs)
                padded_tokens = torch.zeros(len(seqs), max_len, dtype=torch.long)
                attention_mask = torch.zeros(len(seqs), max_len, dtype=torch.long)
                
                for i, seq in enumerate(seqs):
                    seq_len = len(seq)
                    padded_tokens[i, :seq_len] = seq
                    attention_mask[i, :seq_len] = 1
                    
                batch[f"{seq_type}_input_ids"] = padded_tokens
                batch[f"{seq_type}_attention_mask"] = attention_mask
                
            else:
                # Tokenize other structure sequences
                structure_encodings = self.tokenizer(
                    seqs,
                    padding=True,
                    truncation=True if self.max_length else False,
                    max_length=self.max_length,
                    return_tensors="pt"
                )
                
                batch[f"{seq_type}_input_ids"] = structure_encodings["input_ids"]
                batch[f"{seq_type}_attention_mask"] = structure_encodings["attention_mask"]
        
        return batch