File size: 892 Bytes
6fc683c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch

from dataclasses import dataclass
from typing import List, Dict, Any
from transformers import BatchEncoding, DataCollatorWithPadding


@dataclass
class CrossEncoderCollator(DataCollatorWithPadding):
    """Collate grouped cross-encoder examples into one padded batch.

    Each incoming example maps feature names to a sequence holding one entry
    per candidate in that example's group. The collator flattens every group
    into one feature dict per candidate, pads all candidates together with
    ``self.tokenizer.pad`` and attaches a ``labels`` tensor with one entry per
    example (group), not per flattened candidate.
    """

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        """Flatten, pad and label a list of grouped examples.

        Args:
            features: One dict per example. Every value must be a sequence
                (list or tensor), and all values of a given example must have
                the same length, namely that example's group size.

        Returns:
            The ``BatchEncoding`` produced by ``self.tokenizer.pad`` (using this
            collator's ``padding``, ``pad_to_multiple_of`` and
            ``return_tensors`` settings). It has one row per flattened
            candidate, in example order and then in-group order. A ``labels``
            entry is added: a ``torch.long`` tensor of shape
            ``(len(features),)`` filled with zeros.

        Raises:
            ValueError: If an example has no fields, or its fields have
                differing lengths. Flattening such an example would otherwise
                misalign fields or silently drop candidates.
        """
        unpack_features = []
        for ex_idx, ex in enumerate(features):
            lengths = {k: len(v) for k, v in ex.items()}
            # Exactly one distinct length means the example is non-empty and
            # every field describes the same set of candidates.
            if len(set(lengths.values())) != 1:
                raise ValueError(
                    f"Example {ex_idx} must have at least one field and all of "
                    f"its fields must have the same length; got {lengths}")
            keys = list(ex.keys())
            # Walk the group's candidates in lockstep across all fields,
            # producing one flat feature dict per candidate.
            for values in zip(*(ex[k] for k in keys)):
                unpack_features.append(dict(zip(keys, values)))

        collated_batch_dict = self.tokenizer.pad(
            unpack_features,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)

        # One target per group, always 0. NOTE(review): this presumably means
        # the positive candidate is placed first in every group (class index 0
        # of a per-group classification); confirm against the dataset builder.
        collated_batch_dict['labels'] = torch.zeros(len(features), dtype=torch.long)
        return collated_batch_dict