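"""Batch collator for biencoder (dense retrieval) training.

Pads 'q_'-prefixed query fields and 'd_'-prefixed document fields separately,
then merges them back into a single BatchEncoding.
"""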
import torch

from dataclasses import dataclass
from typing import List, Dict, Any
from transformers import DataCollatorWithPadding, BatchEncoding


def _unpack_doc_values(features: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
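    """Flatten features whose values are equal-length lists (one entry per
    candidate document) into a list with one dict per document."""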
    doc_examples = []
    for f in features:
        keys = list(f.keys())
        lists_per_key = len(f[keys[0]])
        for idx in range(lists_per_key):
            doc_examples.append({k: f[k][idx] for k in keys})
    return doc_examples


@dataclass
class BiencoderCollator(DataCollatorWithPadding):
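    """Collate a list of features into one training batch.

    Each feature holds flat 'q_'-prefixed query fields plus 'd_'-prefixed
    document fields whose values are lists with one entry per candidate
    document for that query.
    """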

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        q_prefix, d_prefix = 'q_', 'd_'
        # split each feature into query fields ('q_' prefix) and document
        # fields ('d_' prefix), stripping the prefixes before padding
        query_examples = [
            {k[len(q_prefix):]: v for k, v in f.items() if k.startswith(q_prefix)}
            for f in features]
        doc_examples = _unpack_doc_values(
            [{k[len(d_prefix):]: v for k, v in f.items() if k.startswith(d_prefix)}
             for f in features])
        assert len(doc_examples) % len(query_examples) == 0, \
            '{} docs for {} queries'.format(len(doc_examples), len(query_examples))

        # already truncated during tokenization
        q_collated = self.tokenizer.pad(
            query_examples,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)
        d_collated = self.tokenizer.pad(
            doc_examples,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)

        # merge into a single BatchEncoding, restoring the stripped prefixes
        for k in list(q_collated.keys()):
            q_collated[q_prefix + k] = q_collated.pop(k)
        for k in d_collated:
            q_collated[d_prefix + k] = d_collated[k]

        merged_batch_dict = q_collated
        # dummy placeholder for the 'labels' field; it is not used to compute the loss
        labels = torch.zeros(len(query_examples), dtype=torch.long)
        merged_batch_dict['labels'] = labels

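        # pass through per-document scores when provided (presumably teacher
        # scores for knowledge distillation), one row of floats per query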
        if 'kd_labels' in features[0]:
            kd_labels = torch.stack([torch.tensor(f['kd_labels']) for f in features], dim=0).float()
            merged_batch_dict['kd_labels'] = kd_labels
        return merged_batch_dict
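

# A minimal usage sketch, not part of the original module: the checkpoint name
# ('bert-base-uncased') and the toy texts below are assumptions for illustration.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained('bert-base-uncased')
    q = tok('what does a biencoder collator do?', truncation=True, max_length=32)
    docs = tok(['it pads queries and documents separately.',
                'an unrelated negative passage.'],
               truncation=True, max_length=32)

    # One feature per query: 'q_' fields are flat lists, while 'd_' fields are
    # lists of per-document lists (here, two candidate documents).
    feature = {'q_' + k: v for k, v in q.items()}
    feature.update({'d_' + k: v for k, v in docs.items()})

    collator = BiencoderCollator(tokenizer=tok, pad_to_multiple_of=8)
    batch = collator([feature])
    print({k: tuple(v.shape) for k, v in batch.items()})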