# VenusFactory/src/data/collator.py
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Union

import torch
from transformers import PreTrainedTokenizer
VQVAE_CODEBOOK_SIZE = 4096
VQVAE_SPECIAL_TOKENS = {
"MASK": VQVAE_CODEBOOK_SIZE,
"EOS": VQVAE_CODEBOOK_SIZE + 1,
"BOS": VQVAE_CODEBOOK_SIZE + 2,
"PAD": VQVAE_CODEBOOK_SIZE + 3,
"CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4,
}
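
# Special token IDs appended directly after the 4096-entry ESM3 structure
# VQ-VAE codebook (IDs 4096-4100); any embedding table that consumes these
# tokens therefore needs at least VQVAE_CODEBOOK_SIZE + 5 rows.
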
@dataclass
class Collator:
"""Data collator class for protein sequences."""
tokenizer: PreTrainedTokenizer
    max_length: Optional[int] = None
    structure_seq: Optional[List[str]] = None
    problem_type: str = 'classification'
    plm_model: Optional[str] = None
    num_labels: Optional[int] = None
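
    # Expected per-example keys (inferred from the processing in __call__):
    #   "aa_seq"              - amino acid sequence string
    #   "label"               - scalar label, or a comma-separated index string
    #                           when problem_type is multi_label_classification
    #   "prosst_stru_token"   - ProSST structure tokens (ProSST models only)
    #   <structure_seq names> - one additional field per configured name
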
def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
"""Collate function for batching examples."""
        # Initialize lists for sequences, labels, and (for ProSST) structure tokens
        aa_seqs, labels = [], []
        str_tokens = []
        structure_seqs = {
            seq_type: [] for seq_type in (self.structure_seq or [])
        }
# Process each example
for e in examples:
# Process sequences
aa_seq = self.process_sequence(e["aa_seq"])
aa_seqs.append(aa_seq)
if "ProSST" in self.plm_model:
stru_token = self.process_stru_tokens(e["prosst_stru_token"])
str_tokens.append(stru_token)
# Process structure sequences if needed
for seq_type in structure_seqs:
if seq_type == 'esm3_structure_seq':
processed_seq = self.process_esm3_structure_seq(e[seq_type])
else:
processed_seq = self.process_sequence(e[seq_type])
structure_seqs[seq_type].append(processed_seq)
            # Process labels (multi-label targets arrive as comma-separated index strings)
            if self.problem_type == 'multi_label_classification':
                label_indices = [int(l) for l in e['label'].split(',')]
                binary_label = [0] * self.num_labels
                for index in label_indices:
                    binary_label[index] = 1
                labels.append(binary_label)
            else:
                labels.append(e["label"])
        # Tokenize sequences (str_tokens is empty unless a ProSST model is used)
        batch = self.tokenize_sequences(aa_seqs, structure_seqs, str_tokens)
# Add labels to batch
batch["label"] = torch.as_tensor(
labels,
dtype=torch.float if self.problem_type == 'regression' else torch.long
)
return batch
    def process_sequence(self, seq: str) -> str:
        """Process a sequence based on the model type."""
        if 'prot_bert' in self.plm_model or "prot_t5" in self.plm_model:
            # ProtTrans tokenizers expect space-separated residues, with
            # rare/ambiguous amino acids (U, Z, O, B) mapped to X
            seq = " ".join(list(seq))
            seq = re.sub(r"[UZOB]", "X", seq)
        return seq
def process_esm3_structure_seq(self, seq: List[int]) -> torch.Tensor:
"""Process ESM3 structure sequence."""
return torch.tensor([VQVAE_SPECIAL_TOKENS["BOS"]] + seq + [VQVAE_SPECIAL_TOKENS["EOS"]])
    def process_stru_tokens(self, seq: Union[str, List[int]]) -> torch.Tensor:
        """Process ProSST structure tokens.

        Accepts either a list/tuple of ints or a string such as "[12, 7, 309]".
        """
        if isinstance(seq, str):
            seq_clean = seq.strip("[]").replace(" ", "")
            tokens = [int(num) for num in seq_clean.split(',')] if seq_clean else []
        elif isinstance(seq, (list, tuple)):
            tokens = [int(x) for x in seq]
        else:
            raise TypeError(f"Unsupported ProSST structure token type: {type(seq)}")
        return torch.tensor(tokens, dtype=torch.long)
def tokenize_sequences(
self,
aa_seqs: List[str],
structure_seqs: Dict[str, List[str]],
        str_tokens: Optional[List[torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
"""Tokenize all sequences."""
        # Process amino acid sequences
        if "esm1b" in self.plm_model or "esm1v" in self.plm_model:
            # ESM-1b/ESM-1v support at most 1024 positions; keep 2 for BOS/EOS
            self.max_length = 1022
aa_encodings = self.tokenizer(
aa_seqs,
padding=True,
            truncation=bool(self.max_length),
max_length=self.max_length,
return_tensors="pt"
)
        aa_max_length = aa_encodings["input_ids"].shape[1]
        if str_tokens:
            # ProSST structure tokens carry no BOS/EOS, so pad (and, if needed,
            # truncate) them to aa_max_length - 2 to line up with the residues
            target_len = aa_max_length - 2
            padded_tokens = []
            for tokens in str_tokens:
                struct_sequence = [int(num) for num in tokens][:target_len]
                padded_tokens.append(struct_sequence + [0] * (target_len - len(struct_sequence)))
            batch = {
                "aa_seq_input_ids": aa_encodings["input_ids"],
                "aa_seq_attention_mask": aa_encodings["attention_mask"],
                "aa_seq_stru_tokens": torch.tensor(padded_tokens, dtype=torch.long)
            }
        else:
            batch = {
                "aa_seq_input_ids": aa_encodings["input_ids"],
                "aa_seq_attention_mask": aa_encodings["attention_mask"]
            }
# Process structure sequences if provided
for seq_type, seqs in structure_seqs.items():
if not seqs:
continue
            if seq_type == 'esm3_structure_seq':
                # ESM3 structure sequences are already tokenized tensors;
                # pad them manually to the longest sequence in the batch
                max_len = max(len(seq) for seq in seqs)
                padded_tokens = torch.zeros(len(seqs), max_len, dtype=torch.long)
                attention_mask = torch.zeros(len(seqs), max_len, dtype=torch.long)
                for i, seq in enumerate(seqs):
                    seq_len = len(seq)
                    padded_tokens[i, :seq_len] = seq
                    attention_mask[i, :seq_len] = 1
batch[f"{seq_type}_input_ids"] = padded_tokens
batch[f"{seq_type}_attention_mask"] = attention_mask
else:
# Tokenize other structure sequences
structure_encodings = self.tokenizer(
seqs,
padding=True,
                    truncation=bool(self.max_length),
max_length=self.max_length,
return_tensors="pt"
)
batch[f"{seq_type}_input_ids"] = structure_encodings["input_ids"]
batch[f"{seq_type}_attention_mask"] = structure_encodings["attention_mask"]
return batch
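

# Usage sketch (illustrative only; the checkpoint name, dataset variable, and
# batch size below are assumptions, not part of this module):
#
#   from torch.utils.data import DataLoader
#   from transformers import AutoTokenizer
#
#   tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
#   collator = Collator(
#       tokenizer=tokenizer,
#       max_length=1024,
#       problem_type="classification",
#       plm_model="facebook/esm2_t33_650M_UR50D",
#       num_labels=2,
#   )
#   loader = DataLoader(train_dataset, batch_size=8, collate_fn=collator)
#   batch = next(iter(loader))
#   # batch contains "aa_seq_input_ids", "aa_seq_attention_mask", and "label"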