import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer

VQVAE_CODEBOOK_SIZE = 4096
VQVAE_SPECIAL_TOKENS = {
    "MASK": VQVAE_CODEBOOK_SIZE,
    "EOS": VQVAE_CODEBOOK_SIZE + 1,
    "BOS": VQVAE_CODEBOOK_SIZE + 2,
    "PAD": VQVAE_CODEBOOK_SIZE + 3,
    "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4,
}


@dataclass
class Collator:
    """Data collator class for protein sequences."""

    tokenizer: PreTrainedTokenizer
    max_length: Optional[int] = None
    structure_seq: Optional[List[str]] = None
    problem_type: str = "classification"
    plm_model: Optional[str] = None
    num_labels: Optional[int] = None

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate function for batching examples."""
        is_prosst = "ProSST" in self.plm_model

        # Initialize lists to store sequences and labels
        aa_seqs, labels = [], []
        str_tokens = [] if is_prosst else None
        structure_seqs = {seq_type: [] for seq_type in (self.structure_seq or [])}

        # Process each example
        for e in examples:
            # Process amino acid sequence
            aa_seqs.append(self.process_sequence(e["aa_seq"]))

            # Process ProSST structure tokens
            if is_prosst:
                str_tokens.append(self.process_stru_tokens(e["prosst_stru_token"]))

            # Process structure sequences if needed
            for seq_type in structure_seqs:
                if seq_type == "esm3_structure_seq":
                    processed_seq = self.process_esm3_structure_seq(e[seq_type])
                else:
                    processed_seq = self.process_sequence(e[seq_type])
                structure_seqs[seq_type].append(processed_seq)

            # Convert comma-separated label indices into a multi-hot vector
            if self.problem_type == "multi_label_classification":
                indices = [int(l) for l in e["label"].split(",")]
                binary_list = [0] * self.num_labels
                for index in indices:
                    binary_list[index] = 1
                e["label"] = binary_list

            # Process labels
            labels.append(e["label"])

        # Tokenize sequences
        if is_prosst:
            batch = self.tokenize_sequences(aa_seqs, structure_seqs, str_tokens)
        else:
            batch = self.tokenize_sequences(aa_seqs, structure_seqs)

        # Add labels to batch
        batch["label"] = torch.as_tensor(
            labels,
            dtype=torch.float if self.problem_type == "regression" else torch.long,
        )
        return batch

    def process_sequence(self, seq: str) -> str:
        """Process sequence based on model type."""
        if "prot_bert" in self.plm_model or "prot_t5" in self.plm_model:
            # ProtBert/ProtT5 expect space-separated residues with rare
            # amino acids (U, Z, O, B) mapped to X.
            seq = " ".join(list(seq))
            seq = re.sub(r"[UZOB]", "X", seq)
        return seq

    def process_esm3_structure_seq(self, seq: List[int]) -> torch.Tensor:
        """Process ESM3 structure sequence by adding BOS/EOS special tokens."""
        return torch.tensor(
            [VQVAE_SPECIAL_TOKENS["BOS"]] + seq + [VQVAE_SPECIAL_TOKENS["EOS"]]
        )

    def process_stru_tokens(self, seq: Union[str, List[int]]) -> torch.Tensor:
        """Process ProSST structure tokens given as a string or list of ints."""
        if isinstance(seq, str):
            # Accept a serialized list such as "[1, 2, 3]"
            seq_clean = seq.strip("[]").replace(" ", "")
            tokens = list(map(int, seq_clean.split(","))) if seq_clean else []
        elif isinstance(seq, (list, tuple)):
            tokens = [int(x) for x in seq]
        else:
            raise TypeError(f"Unsupported ProSST structure token type: {type(seq)}")
        return torch.tensor(tokens)

    def tokenize_sequences(
        self,
        aa_seqs: List[str],
        structure_seqs: Dict[str, List[str]],
        str_tokens: Optional[List[torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Tokenize all sequences."""
        # ESM-1b/ESM-1v have a fixed maximum input length of 1022 residues
        if "esm1b" in self.plm_model or "esm1v" in self.plm_model:
            self.max_length = 1022

        # Process amino acid sequences
        aa_encodings = self.tokenizer(
            aa_seqs,
            padding=True,
            truncation=bool(self.max_length),
            max_length=self.max_length,
            return_tensors="pt",
        )
        aa_max_length = aa_encodings["input_ids"].shape[1]

        if str_tokens:
            # Pad ProSST structure tokens to the amino acid length, excluding
            # the two special tokens added by the tokenizer.
            padded_tokens = []
            for tokens in str_tokens:
                struct_sequence = [int(num) for num in tokens]
                padded_tokens.append(
                    struct_sequence + [0] * (aa_max_length - len(struct_sequence) - 2)
                )
            batch = {
                "aa_seq_input_ids": aa_encodings["input_ids"],
                "aa_seq_attention_mask": aa_encodings["attention_mask"],
                "aa_seq_stru_tokens": torch.tensor(padded_tokens, dtype=torch.long),
            }
        else:
            batch = {
                "aa_seq_input_ids": aa_encodings["input_ids"],
                "aa_seq_attention_mask": aa_encodings["attention_mask"],
            }

        # Process structure sequences if provided
        for seq_type, seqs in structure_seqs.items():
            if not seqs:
                continue

            if seq_type == "esm3_structure_seq":
                # ESM3 structure sequences are already tokenized;
                # pad them to the longest sequence in the batch.
                max_len = max(len(seq) for seq in seqs)
                padded_tokens = torch.zeros(len(seqs), max_len, dtype=torch.long)
                attention_mask = torch.zeros(len(seqs), max_len, dtype=torch.long)
                for i, seq in enumerate(seqs):
                    seq_len = len(seq)
                    padded_tokens[i, :seq_len] = seq
                    attention_mask[i, :seq_len] = 1
                batch[f"{seq_type}_input_ids"] = padded_tokens
                batch[f"{seq_type}_attention_mask"] = attention_mask
            else:
                # Tokenize other structure sequences with the shared tokenizer
                structure_encodings = self.tokenizer(
                    seqs,
                    padding=True,
                    truncation=bool(self.max_length),
                    max_length=self.max_length,
                    return_tensors="pt",
                )
                batch[f"{seq_type}_input_ids"] = structure_encodings["input_ids"]
                batch[f"{seq_type}_attention_mask"] = structure_encodings["attention_mask"]

        return batch
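

# --- Usage sketch (illustrative only) ----------------------------------------
# A minimal example of plugging the Collator into a torch DataLoader for a
# plain sequence-classification setup. The checkpoint name and the toy
# examples below are placeholders, not part of this module; substitute the
# tokenizer, plm_model string, and extra fields your pipeline actually uses
# (e.g. "prosst_stru_token" for ProSST or "esm3_structure_seq" for ESM3).
if __name__ == "__main__":
    from torch.utils.data import DataLoader
    from transformers import AutoTokenizer

    # Placeholder checkpoint for demonstration purposes
    tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
    collator = Collator(
        tokenizer=tokenizer,
        max_length=1024,
        problem_type="classification",
        plm_model="esm2_t6_8M_UR50D",
        num_labels=2,
    )

    # Toy examples with the minimal fields the collator expects
    toy_examples = [
        {"aa_seq": "MKTAYIAKQR", "label": 0},
        {"aa_seq": "GSHMLEDPVE", "label": 1},
    ]
    loader = DataLoader(toy_examples, batch_size=2, collate_fn=collator)
    batch = next(iter(loader))
    print({k: v.shape for k, v in batch.items()})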