import torch
import torch.nn as nn

from typing import Optional, Dict
from transformers import (
    PreTrainedModel,
    AutoModelForSequenceClassification,
)
from transformers.modeling_outputs import SequenceClassifierOutput

from config import Arguments

class Reranker(nn.Module):
    def __init__(self, hf_model: PreTrainedModel, args: Arguments):
        super().__init__()
        self.hf_model = hf_model
        self.args = args

        # Listwise cross-entropy over the candidate passages of each query,
        # plus a symmetric KL term used for R-Drop regularization.
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.kl_loss_fn = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)

    def forward(self, batch: Dict[str, torch.Tensor]) -> SequenceClassifierOutput:
        input_batch_dict = {k: v for k, v in batch.items() if k != 'labels'}

        if self.args.rerank_forward_factor > 1:
            # Cheap filtering pass: score all candidates without gradients and
            # keep only the top-k passages per query for the expensive training
            # forward. The positive passage is expected at index 0.
            assert torch.sum(batch['labels']).long().item() == 0
            assert all(len(v.shape) == 2 for v in input_batch_dict.values())

            is_train = self.hf_model.training
            self.hf_model.eval()

            with torch.no_grad():
                outputs: SequenceClassifierOutput = self.hf_model(**input_batch_dict, return_dict=True)
                outputs.logits = outputs.logits.view(-1, self.args.train_n_passages)
                # make sure the target passage is not masked out
                outputs.logits[:, 0].fill_(float('inf'))

                k = self.args.train_n_passages // self.args.rerank_forward_factor
                _, topk_indices = torch.topk(outputs.logits, k=k, dim=-1, largest=True)
                # Convert per-query indices into flat row indices over the batch.
                topk_indices += self.args.train_n_passages * torch.arange(0, topk_indices.shape[0],
                                                                          dtype=torch.long,
                                                                          device=topk_indices.device).unsqueeze(-1)
                topk_indices = topk_indices.view(-1)
                input_batch_dict = {k: v.index_select(dim=0, index=topk_indices) for k, v in input_batch_dict.items()}

            self.hf_model.train(is_train)

        n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor
        if self.args.rerank_use_rdrop and self.training:
            # R-Drop: run every example twice (two dropout masks) by duplicating the batch.
            input_batch_dict = {k: torch.cat([v, v], dim=0) for k, v in input_batch_dict.items()}

        outputs: SequenceClassifierOutput = self.hf_model(**input_batch_dict, return_dict=True)

        if self.args.rerank_use_rdrop and self.training:
            logits = outputs.logits.view(2, -1, n_psg_per_query)
            outputs.logits = logits[0, :, :].contiguous()

            # Symmetric KL between the two dropout passes, plus the averaged
            # listwise cross-entropy of both passes against the gold indices.
            log_prob = torch.log_softmax(logits, dim=2)
            log_prob1, log_prob2 = log_prob[0, :, :], log_prob[1, :, :]
            rdrop_loss = 0.5 * (self.kl_loss_fn(log_prob1, log_prob2) + self.kl_loss_fn(log_prob2, log_prob1))
            ce_loss = 0.5 * (self.cross_entropy(log_prob1, batch['labels'])
                             + self.cross_entropy(log_prob2, batch['labels']))
            outputs.loss = rdrop_loss + ce_loss
        else:
            outputs.logits = outputs.logits.view(-1, n_psg_per_query)
            loss = self.cross_entropy(outputs.logits, batch['labels'])
            outputs.loss = loss

        return outputs

    @classmethod
    def from_pretrained(cls, all_args: Arguments, *args, **kwargs):
        hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
        return cls(hf_model, all_args)

    def save_pretrained(self, output_dir: str):
        self.hf_model.save_pretrained(output_dir)

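
# Illustrative training-time usage (a sketch, not part of the original module):
# the Arguments fields and base checkpoint below are assumptions; in the real
# pipeline `args` comes from config.Arguments and the batch from a data
# collator that packs `train_n_passages` query-passage pairs per query with
# the positive passage at index 0.
#
#   args = ...  # Arguments exposing train_n_passages, rerank_forward_factor, rerank_use_rdrop
#   reranker = Reranker.from_pretrained(args, 'bert-base-uncased', num_labels=1)
#   batch = {
#       'input_ids': ...,       # (batch_size * train_n_passages, seq_len)
#       'attention_mask': ...,  # same shape as input_ids
#       'labels': ...,          # (batch_size,) zeros: the positive is at index 0
#   }
#   outputs = reranker(batch)
#   outputs.loss.backward()
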
class RerankerForInference(nn.Module):
    def __init__(self, hf_model: Optional[PreTrainedModel] = None):
        super().__init__()
        self.hf_model = hf_model
        # Guard against the default None so construction does not fail before
        # a model is attached; scoring still requires a loaded hf_model.
        if self.hf_model is not None:
            self.hf_model.eval()

    @torch.no_grad()
    def forward(self, batch) -> SequenceClassifierOutput:
        return self.hf_model(**batch)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str):
        hf_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)
        return cls(hf_model)
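

if __name__ == '__main__':
    # Minimal inference smoke test (a sketch, not part of the original file):
    # the checkpoint name is only an example cross-encoder; substitute the
    # model this project actually trains or serves.
    from transformers import AutoTokenizer

    model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'  # example checkpoint (assumption)
    reranker = RerankerForInference.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    encoded = tokenizer(
        'how long do transformers take to train',
        'Training time depends heavily on model size and hardware.',
        return_tensors='pt', truncation=True, max_length=192)
    score = reranker(encoded).logits.squeeze()
    print(f'relevance score: {score.item():.4f}')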