from dataclasses import dataclass
from typing import List, Dict, Any

import torch
from transformers import DataCollatorWithPadding, BatchEncoding

def _unpack_doc_values(features: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Flatten per-query lists of document encodings into a flat list of
    # single-document dicts, preserving feature order.
    doc_examples = []
    for f in features:
        keys = list(f.keys())
        lists_per_key = len(f[keys[0]])
        for idx in range(lists_per_key):
            doc_examples.append({k: f[k][idx] for k in keys})
    return doc_examples
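
# Illustrative example (hypothetical values) of _unpack_doc_values: two
# features, each holding two tokenized documents per key, flatten to four
# single-document dicts:
#   [{'input_ids': [[1, 2], [3, 4]]}, {'input_ids': [[5, 6], [7, 8]]}]
#     -> [{'input_ids': [1, 2]}, {'input_ids': [3, 4]},
#         {'input_ids': [5, 6]}, {'input_ids': [7, 8]}]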

@dataclass
class BiencoderCollator(DataCollatorWithPadding):

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        q_prefix, d_prefix = 'q_', 'd_'
        # Split each feature into its query part and its document part by
        # key prefix, stripping the prefix so the tokenizer sees plain keys.
        query_examples = [{k[len(q_prefix):]: v for k, v in f.items() if k.startswith(q_prefix)}
                          for f in features]
        doc_examples = _unpack_doc_values(
            [{k[len(d_prefix):]: v for k, v in f.items() if k.startswith(d_prefix)} for f in features])
        assert len(doc_examples) % len(query_examples) == 0, \
            '{} docs and {} queries'.format(len(doc_examples), len(query_examples))

        # Inputs were already truncated during tokenization, so only pad here.
        q_collated = self.tokenizer.pad(
            query_examples,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)
        d_collated = self.tokenizer.pad(
            doc_examples,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)

        # Merge into a single BatchEncoding by re-attaching the prefixes.
        for k in list(q_collated.keys()):
            q_collated[q_prefix + k] = q_collated[k]
            del q_collated[k]
        for k in d_collated:
            q_collated[d_prefix + k] = d_collated[k]
        merged_batch_dict = q_collated

        # Dummy placeholder for the "labels" field; it is not used to compute the loss.
        labels = torch.zeros(len(query_examples), dtype=torch.long)
        merged_batch_dict['labels'] = labels

        # Optional knowledge-distillation labels, stacked per query if present.
        if 'kd_labels' in features[0]:
            kd_labels = torch.stack([torch.tensor(f['kd_labels']) for f in features], dim=0).float()
            merged_batch_dict['kd_labels'] = kd_labels
        return merged_batch_dict
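
# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical demonstration of the collator: feature dicts are
# assumed to carry 'q_'-prefixed query encodings and 'd_'-prefixed lists of
# per-query document encodings, as an upstream tokenization step would emit.
# The checkpoint name and token ids below are placeholders.
if __name__ == '__main__':
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    collator = BiencoderCollator(tokenizer=tokenizer, pad_to_multiple_of=8)
    features = [{
        'q_input_ids': [101, 2054, 102],
        'q_attention_mask': [1, 1, 1],
        'd_input_ids': [[101, 2023, 102], [101, 2008, 2003, 102]],
        'd_attention_mask': [[1, 1, 1], [1, 1, 1, 1]],
    }]
    batch = collator(features)
    # Expected keys: q_input_ids, q_attention_mask, d_input_ids,
    # d_attention_mask, labels; all values are padded tensors.
    print({k: tuple(v.shape) for k, v in batch.items()})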