import torch
import torch.nn as nn

from typing import Optional, Dict
from transformers import (
    PreTrainedModel,
    AutoModelForSequenceClassification
)
from transformers.modeling_outputs import SequenceClassifierOutput

from config import Arguments


class Reranker(nn.Module):
    """Cross-encoder reranker trained with a listwise cross-entropy loss.

    Optionally pre-filters the candidate passages with a no-grad forward pass
    (rerank_forward_factor > 1) and regularizes training with R-Drop
    (rerank_use_rdrop).
    """

    def __init__(self, hf_model: PreTrainedModel, args: Arguments):
        super().__init__()
        self.hf_model = hf_model
        self.args = args

        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        # log_target=True: both arguments to the KL loss are log-probabilities
        self.kl_loss_fn = nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, batch: Dict[str, torch.Tensor]) -> SequenceClassifierOutput:
        input_batch_dict = {k: v for k, v in batch.items() if k != 'labels'}

        if self.args.rerank_forward_factor > 1:
            # Pre-filtering assumes the positive passage sits at index 0 of
            # every group (i.e. all labels are zero) and that every tensor is
            # the flattened 2-D [n_queries * train_n_passages, seq_len] layout.
            assert torch.sum(batch['labels']).long().item() == 0
            assert all(len(v.shape) == 2 for v in input_batch_dict.values())

            # Score all candidates with a cheap no-grad forward pass in eval
            # mode, then keep only the top-k per query for the expensive
            # training pass below.
            is_train = self.hf_model.training
            self.hf_model.eval()

            with torch.no_grad():
                outputs: SequenceClassifierOutput = self.hf_model(**input_batch_dict, return_dict=True)
                outputs.logits = outputs.logits.view(-1, self.args.train_n_passages)
                # make sure the target passage (index 0) is never filtered out
                outputs.logits[:, 0].fill_(float('inf'))

                k = self.args.train_n_passages // self.args.rerank_forward_factor
                _, topk_indices = torch.topk(outputs.logits, k=k, dim=-1, largest=True)
                # topk_indices are positions within each query's group; add a
                # per-row offset of train_n_passages * query_idx to turn them
                # into flat row indices into the whole batch
                topk_indices += self.args.train_n_passages * torch.arange(0, topk_indices.shape[0],
                                                                          dtype=torch.long,
                                                                          device=topk_indices.device).unsqueeze(-1)
                topk_indices = topk_indices.view(-1)

                input_batch_dict = {k: v.index_select(dim=0, index=topk_indices) for k, v in input_batch_dict.items()}

            # restore the mode the model was in before pre-filtering
            self.hf_model.train(is_train)

        n_psg_per_query = self.args.train_n_passages // self.args.rerank_forward_factor

        if self.args.rerank_use_rdrop and self.training:
            # R-Drop: feed every example through the model twice so that each
            # copy sees an independent dropout mask
            input_batch_dict = {k: torch.cat([v, v], dim=0) for k, v in input_batch_dict.items()}

        outputs: SequenceClassifierOutput = self.hf_model(**input_batch_dict, return_dict=True)

        if self.args.rerank_use_rdrop and self.training:
            logits = outputs.logits.view(2, -1, n_psg_per_query)
            # expose only the first copy's logits to downstream consumers
            outputs.logits = logits[0, :, :].contiguous()
            log_prob = torch.log_softmax(logits, dim=2)
            log_prob1, log_prob2 = log_prob[0, :, :], log_prob[1, :, :]
            # symmetric KL between the two dropout-perturbed distributions
            rdrop_loss = 0.5 * (self.kl_loss_fn(log_prob1, log_prob2) + self.kl_loss_fn(log_prob2, log_prob1))
            # CrossEntropyLoss applies log_softmax internally; since log_softmax
            # is idempotent, passing log-probabilities instead of raw logits
            # yields the same cross-entropy value
            ce_loss = 0.5 * (self.cross_entropy(log_prob1, batch['labels'])
                             + self.cross_entropy(log_prob2, batch['labels']))

            outputs.loss = rdrop_loss + ce_loss
        else:
            # listwise cross-entropy over each query's group of passages
            outputs.logits = outputs.logits.view(-1, n_psg_per_query)
            outputs.loss = self.cross_entropy(outputs.logits, batch['labels'])

        return outputs

    @classmethod
    def from_pretrained(cls, all_args: Arguments, *args, **kwargs):
        hf_model = AutoModelForSequenceClassification.from_pretrained(*args, **kwargs)
        return cls(hf_model, all_args)

    def save_pretrained(self, output_dir: str):
        self.hf_model.save_pretrained(output_dir)
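
# Note on the training input contract (a summary inferred from `forward` above,
# not part of the original module): `batch` holds the tokenizer outputs for the
# flattened (query, passage) pairs, each of shape
# [n_queries * train_n_passages, seq_len], plus `labels` of shape [n_queries]
# giving the index of the positive passage within each group of
# train_n_passages candidates (all zeros whenever rerank_forward_factor > 1).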


class RerankerForInference(nn.Module):
    def __init__(self, hf_model: Optional[PreTrainedModel] = None):
        super().__init__()
        self.hf_model = hf_model
        # hf_model defaults to None, so guard the eval() call instead of
        # crashing when the model is attached later
        if self.hf_model is not None:
            self.hf_model.eval()

    @torch.no_grad()
    def forward(self, batch) -> SequenceClassifierOutput:
        return self.hf_model(**batch)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str):
        hf_model = AutoModelForSequenceClassification.from_pretrained(pretrained_model_name_or_path)
        return cls(hf_model)
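

if __name__ == '__main__':
    # Minimal usage sketch (not part of the original module): score a couple of
    # (query, passage) pairs with RerankerForInference. The checkpoint name is
    # an assumption; any cross-encoder sequence-classification model works.
    from transformers import AutoTokenizer

    model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'  # assumed example checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    reranker = RerankerForInference.from_pretrained(model_name)

    queries = ['what is python'] * 2
    passages = ['Python is a programming language.', 'Pythons are snakes.']
    batch = tokenizer(queries, passages, padding=True, truncation=True, return_tensors='pt')

    # each pair gets a single relevance logit; higher means more relevant
    scores = reranker(batch).logits.squeeze(-1)
    print(scores.tolist())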