Tzktz's picture
Upload 7664 files
6fc683c verified
raw
history blame contribute delete
892 Bytes
import torch
from dataclasses import dataclass
from typing import List, Dict, Any
from transformers import BatchEncoding, DataCollatorWithPadding
@dataclass
class CrossEncoderCollator(DataCollatorWithPadding):
    """Collator that flattens grouped examples before padding them.

    Each incoming feature is a dict whose values are parallel sequences
    (one entry per member of the group).  The collator splits every
    feature into one dict per member, pads the flattened list with the
    tokenizer, and attaches an all-zero ``labels`` tensor with one entry
    per *incoming* feature.
    """

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        """Flatten, pad and label a batch.

        Args:
            features: One dict per group; every value is indexable and
                holds one item per group member.

        Returns:
            The padded batch returned by ``self.tokenizer.pad`` with an
            extra ``'labels'`` entry: a ``torch.long`` zero tensor of
            length ``len(features)``.
        """
        flattened = []
        for example in features:
            field_names = list(example.keys())
            # The group size is read from the first field only; every
            # other field is indexed up to that same size.
            # NOTE(review): a disabled check in an earlier revision
            # expected 8 entries per field -- confirm with the dataset.
            group_size = len(example[field_names[0]])
            flattened.extend(
                {name: example[name][pos] for name in field_names}
                for pos in range(group_size)
            )

        batch = self.tokenizer.pad(
            flattened,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors,
        )

        # One label per group, always 0.
        # NOTE(review): presumably the positive candidate sits at
        # position 0 of each group -- verify against the data pipeline.
        num_groups = len(features)
        batch['labels'] = torch.zeros(num_groups, dtype=torch.long)
        return batch