import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional

import torch
from transformers import PreTrainedTokenizer

VQVAE_CODEBOOK_SIZE = 4096
VQVAE_SPECIAL_TOKENS = {
    "MASK": VQVAE_CODEBOOK_SIZE,
    "EOS": VQVAE_CODEBOOK_SIZE + 1,
    "BOS": VQVAE_CODEBOOK_SIZE + 2,
    "PAD": VQVAE_CODEBOOK_SIZE + 3,
    "CHAINBREAK": VQVAE_CODEBOOK_SIZE + 4,
}


@dataclass
class Collator:
    """Data collator for protein sequence (and optional structure) batches."""

    tokenizer: PreTrainedTokenizer
    max_length: Optional[int] = None
    structure_seq: Optional[List[str]] = None
    problem_type: str = 'classification'
    plm_model: Optional[str] = None
    num_labels: Optional[int] = None

    def __call__(self, examples: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
        """Collate a list of examples into a batched tensor dictionary."""
        # Initialize containers for amino-acid sequences, labels and
        # (for ProSST models) pre-computed structure tokens.
        if "ProSST" in self.plm_model:
            aa_seqs, labels, str_tokens = [], [], []
        else:
            aa_seqs, labels = [], []
        structure_seqs = {
            seq_type: [] for seq_type in (self.structure_seq or [])
        }

        # Process each example
        for e in examples:
            # Process the amino-acid sequence
            aa_seq = self.process_sequence(e["aa_seq"])
            aa_seqs.append(aa_seq)

            if "ProSST" in self.plm_model:
                stru_token = self.process_stru_tokens(e["prosst_stru_token"])
                str_tokens.append(stru_token)

            # Process structure sequences if needed
            for seq_type in structure_seqs:
                if seq_type == 'esm3_structure_seq':
                    processed_seq = self.process_esm3_structure_seq(e[seq_type])
                else:
                    processed_seq = self.process_sequence(e[seq_type])
                structure_seqs[seq_type].append(processed_seq)

            # For multi-label classification, convert the comma-separated
            # label string into a multi-hot vector of length num_labels.
            if self.problem_type == 'multi_label_classification':
                label_indices = [int(l) for l in e['label'].split(',')]
                binary_list = [0] * self.num_labels
                for index in label_indices:
                    binary_list[index] = 1
                e['label'] = binary_list

            # Collect the (possibly converted) label
            labels.append(e["label"])

        # Tokenize sequences
        if "ProSST" in self.plm_model:
            batch = self.tokenize_sequences(aa_seqs, structure_seqs, str_tokens)
        else:
            batch = self.tokenize_sequences(aa_seqs, structure_seqs)

        # Add labels to the batch (float targets for regression, long otherwise)
        batch["label"] = torch.as_tensor(
            labels,
            dtype=torch.float if self.problem_type == 'regression' else torch.long
        )
        return batch

    def process_sequence(self, seq: str) -> str:
        """Process a sequence based on the PLM type (ProtBert/ProtT5 expect
        space-separated residues with rare amino acids mapped to X)."""
        if 'prot_bert' in self.plm_model or "prot_t5" in self.plm_model:
            seq = " ".join(list(seq))
            seq = re.sub(r"[UZOB]", "X", seq)
        return seq

    def process_esm3_structure_seq(self, seq: List[int]) -> torch.Tensor:
        """Wrap an ESM3 structure-token sequence with BOS/EOS special tokens."""
        return torch.tensor(
            [VQVAE_SPECIAL_TOKENS["BOS"]] + seq + [VQVAE_SPECIAL_TOKENS["EOS"]]
        )

    def process_stru_tokens(self, seq: List[int]) -> torch.Tensor:
        """Parse ProSST structure tokens from a string like "[1, 2, 3]" or a list of ints."""
        if isinstance(seq, str):
            seq_clean = seq.strip("[]").replace(" ", "")
            tokens = list(map(int, seq_clean.split(','))) if seq_clean else []
        elif isinstance(seq, (list, tuple)):
            tokens = [int(x) for x in seq]
        else:
            raise TypeError(f"Unsupported ProSST structure token type: {type(seq)}")
        return torch.tensor(tokens, dtype=torch.long)

    def tokenize_sequences(
        self,
        aa_seqs: List[str],
        structure_seqs: Dict[str, List[str]],
        str_tokens: Optional[List[torch.Tensor]] = None,
    ) -> Dict[str, torch.Tensor]:
        """Tokenize amino-acid and structure sequences into a batch dict."""
        # ESM-1b/ESM-1v have a fixed context window, so cap the length.
        if "esm1b" in self.plm_model or "esm1v" in self.plm_model:
            self.max_length = 1022
        aa_encodings = self.tokenizer(
            aa_seqs,
            padding=True,
            truncation=self.max_length is not None,
            max_length=self.max_length,
            return_tensors="pt"
        )
        aa_max_length = aa_encodings["input_ids"].size(1)

        if str_tokens:
            # Pad (and truncate if necessary) the ProSST structure tokens so every
            # row has length aa_max_length - 2, leaving room for the two special
            # tokens the tokenizer adds around the amino-acid sequence.
            padded_str_tokens = []
            for tokens in str_tokens:
                struct_sequence = [int(num) for num in tokens][: max(aa_max_length - 2, 0)]
                struct_sequence += [0] * (aa_max_length - 2 - len(struct_sequence))
                padded_str_tokens.append(struct_sequence)
            batch = {
                "aa_seq_input_ids": aa_encodings["input_ids"],
                "aa_seq_attention_mask": aa_encodings["attention_mask"],
                "aa_seq_stru_tokens": torch.tensor(padded_str_tokens, dtype=torch.long)
            }
        else:
            batch = {
                "aa_seq_input_ids": aa_encodings["input_ids"],
                "aa_seq_attention_mask": aa_encodings["attention_mask"]
            }

        # Process structure sequences if provided
        for seq_type, seqs in structure_seqs.items():
            if not seqs:
                continue

            if seq_type == 'esm3_structure_seq':
                # ESM3 structure sequences are already tokenized tensors;
                # right-pad them to the longest sequence in the batch.
                max_len = max(len(seq) for seq in seqs)
                padded_tokens = torch.zeros(len(seqs), max_len, dtype=torch.long)
                attention_mask = torch.zeros(len(seqs), max_len, dtype=torch.long)
                for i, seq in enumerate(seqs):
                    seq_len = len(seq)
                    padded_tokens[i, :seq_len] = seq
                    attention_mask[i, :seq_len] = 1

                batch[f"{seq_type}_input_ids"] = padded_tokens
                batch[f"{seq_type}_attention_mask"] = attention_mask
            else:
                # Tokenize other structure sequences
                structure_encodings = self.tokenizer(
                    seqs,
                    padding=True,
                    truncation=self.max_length is not None,
                    max_length=self.max_length,
                    return_tensors="pt"
                )
                batch[f"{seq_type}_input_ids"] = structure_encodings["input_ids"]
                batch[f"{seq_type}_attention_mask"] = structure_encodings["attention_mask"]

        return batch
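

if __name__ == "__main__":
    # Minimal usage sketch (not part of the collator itself): builds a Collator
    # around an ESM-2 tokenizer and collates two toy classification examples.
    # The checkpoint name and the example records are illustrative assumptions,
    # not values taken from the rest of this project.
    from transformers import AutoTokenizer

    plm_name = "facebook/esm2_t6_8M_UR50D"  # assumed checkpoint for illustration
    tokenizer = AutoTokenizer.from_pretrained(plm_name)

    collator = Collator(
        tokenizer=tokenizer,
        max_length=1024,
        problem_type='classification',
        plm_model=plm_name,
        num_labels=2,
    )

    examples = [
        {"aa_seq": "MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", "label": 0},
        {"aa_seq": "MADEEKLPPGWEKRMSRSSGRVYYFNHITNASQWERPSG", "label": 1},
    ]
    batch = collator(examples)
    # Expected keys: aa_seq_input_ids, aa_seq_attention_mask, label
    print({k: tuple(v.shape) for k, v in batch.items()})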