# NOTE(review): removed web-scrape residue ("Spaces: / Sleeping / Sleeping")
# that was captured from the hosting page UI — it is not part of this module.
import torch
from dataclasses import dataclass
from typing import List, Dict, Any
from transformers import DataCollatorWithPadding, BatchEncoding
def _unpack_doc_values(features: List[Dict[str, Any]]) -> List[Dict[str, Any]]: | |
doc_examples = [] | |
for f in features: | |
keys = list(f.keys()) | |
lists_per_key = len(f[keys[0]]) | |
for idx in range(lists_per_key): | |
doc_examples.append({k: f[k][idx] for k in keys}) | |
return doc_examples | |
class BiencoderCollator(DataCollatorWithPadding):
    """Collator for biencoder training batches.

    Incoming features carry query fields under a 'q_' prefix and document
    fields (lists, one entry per document) under a 'd_' prefix. Queries and
    documents are padded separately by the tokenizer, then merged back into
    a single BatchEncoding with their prefixes restored.
    """

    def __call__(self, features: List[Dict[str, Any]]) -> BatchEncoding:
        q_prefix, d_prefix = 'q_', 'd_'
        # Split each feature into query-side and document-side fields,
        # stripping the prefixes so tokenizer.pad() sees standard keys.
        query_examples = [
            {k[len(q_prefix):]: v for k, v in f.items() if k.startswith(q_prefix)}
            for f in features
        ]
        doc_examples = _unpack_doc_values(
            [{k[len(d_prefix):]: v for k, v in f.items() if k.startswith(d_prefix)}
             for f in features])
        # Every query is expected to contribute the same number of documents.
        assert len(doc_examples) % len(query_examples) == 0, \
            '{} doc and {} queries'.format(len(doc_examples), len(query_examples))

        # Inputs were already truncated at tokenization time, so only
        # padding (no truncation) happens here.
        q_collated = self.tokenizer.pad(
            query_examples,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)
        d_collated = self.tokenizer.pad(
            doc_examples,
            padding=self.padding,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=self.return_tensors)

        # Re-key the query batch in place under the 'q_' prefix, then fold
        # the document batch in under the 'd_' prefix.
        for key in list(q_collated.keys()):
            q_collated[q_prefix + key] = q_collated.pop(key)
        for key, value in d_collated.items():
            q_collated[d_prefix + key] = value
        merged_batch_dict = q_collated

        # dummy placeholder for field "labels", won't use it to compute loss
        merged_batch_dict['labels'] = torch.zeros(len(query_examples), dtype=torch.long)

        # Optional knowledge-distillation targets, stacked per query.
        if 'kd_labels' in features[0]:
            merged_batch_dict['kd_labels'] = torch.stack(
                [torch.tensor(f['kd_labels']) for f in features], dim=0).float()
        return merged_batch_dict