File size: 5,669 Bytes
6fc683c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import copy
import os
import torch
import torch.nn as nn

from contextlib import nullcontext
from torch import Tensor
from torch.distributions import Categorical
from typing import Dict, Optional, Tuple
from dataclasses import dataclass
from transformers import AutoModelForMaskedLM, ElectraModel
from transformers.modeling_outputs import MaskedLMOutput, ModelOutput
from transformers.models.bert import BertForMaskedLM

from logger_config import logger
from config import Arguments
from utils import slice_batch_dict


@dataclass
class ReplaceLMOutput(ModelOutput):
    """Losses and statistics returned by ``ReplaceLM.forward``.

    Every field defaults to ``None``; the field order is part of the
    contract and must not be changed.
    """

    # Total training loss: decoder MLM loss + encoder MLM loss
    # + weighted generator MLM loss.
    loss: Optional[torch.Tensor] = None
    # Detached MLM loss of the encoder.
    encoder_mlm_loss: Optional[torch.Tensor] = None
    # Detached MLM loss of the shallow decoder.
    decoder_mlm_loss: Optional[torch.Tensor] = None
    # Detached MLM loss of the generator (zero when the generator is frozen).
    g_mlm_loss: Optional[torch.Tensor] = None
    # Number of encoder tokens that differ from their label, divided by the
    # number of attended encoder tokens.
    replace_ratio: Optional[torch.Tensor] = None


class ReplaceLM(nn.Module):
    """Replaced-token language model: BERT encoder + shallow decoder + generator.

    A masked-LM generator samples tokens for the masked positions of the
    encoder and decoder inputs. The encoder (a ``BertForMaskedLM``) and a
    shallow decoder (copies of the encoder's last transformer layers) are then
    trained to predict ``labels`` from the resulting token sequences. The
    decoder only sees the encoder through the first-position ([CLS]) hidden
    state.
    """

    def __init__(self, args: Arguments,
                 bert: BertForMaskedLM):
        """Build the encoder / decoder / generator triple.

        Args:
            args: run configuration. This class reads ``rlm_decoder_layers``,
                ``rlm_generator_model_name``, ``rlm_freeze_generator`` and
                ``rlm_generator_mlm_weight``.
            bert: masked-LM model used as the encoder. Its last
                ``args.rlm_decoder_layers`` transformer layers are deep-copied
                to initialise the decoder.
        """
        super(ReplaceLM, self).__init__()
        self.encoder = bert
        # The decoder starts as an independent copy of the top encoder layers.
        self.decoder = copy.deepcopy(self.encoder.bert.encoder.layer[-args.rlm_decoder_layers:])
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        # NOTE: loaded through AutoModelForMaskedLM, so the concrete class
        # depends on the checkpoint; the ElectraModel annotation is only a hint.
        self.generator: ElectraModel = AutoModelForMaskedLM.from_pretrained(args.rlm_generator_model_name)

        if args.rlm_freeze_generator:
            self.generator.eval()
            self.generator.requires_grad_(False)

        self.args = args

        # Imported here rather than at module level (presumably to avoid a
        # circular import between the model and trainer modules).
        from trainers.rlm_trainer import ReplaceLMTrainer
        self.trainer: Optional[ReplaceLMTrainer] = None

    def train(self, mode: bool = True):
        """Set train / eval mode while keeping a frozen generator in eval mode.

        ``nn.Module.train`` recurses into every child module, so without this
        override any ``model.train()`` call would undo the
        ``self.generator.eval()`` done in ``__init__`` and put the frozen
        generator back into training mode (e.g. re-enabling its dropout).

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super(ReplaceLM, self).train(mode)
        if self.args.rlm_freeze_generator:
            self.generator.eval()
        return self

    def forward(self, model_input: Dict[str, torch.Tensor]) -> ReplaceLMOutput:
        """Run one training forward pass and compute all losses.

        Args:
            model_input: batch dictionary holding ``'labels'`` plus encoder and
                decoder inputs under keys prefixed with ``'enc_'`` and
                ``'dec_'`` (``slice_batch_dict`` is assumed to return those
                entries with the prefix stripped - confirm against ``utils``).

        Returns:
            A ``ReplaceLMOutput`` with the total loss, the detached component
            losses and the ratio of replaced encoder tokens.
        """
        enc_prefix, dec_prefix = 'enc_', 'dec_'
        encoder_inputs = slice_batch_dict(model_input, enc_prefix)
        decoder_inputs = slice_batch_dict(model_input, dec_prefix)
        labels = model_input['labels']

        enc_sampled_input_ids, g_mlm_loss = self._replace_tokens(encoder_inputs)
        if self.args.rlm_freeze_generator:
            # A frozen generator contributes nothing to the training objective.
            g_mlm_loss = torch.tensor(0, dtype=torch.float, device=g_mlm_loss.device)
        # The decoder-side sampling never back-propagates into the generator.
        dec_sampled_input_ids, _ = self._replace_tokens(decoder_inputs, no_grad=True)

        encoder_inputs['input_ids'] = enc_sampled_input_ids
        decoder_inputs['input_ids'] = dec_sampled_input_ids
        # use the un-masked version of labels
        encoder_inputs['labels'] = labels
        decoder_inputs['labels'] = labels

        # A token counts as replaced when it has a valid (non-negative) label
        # and the sampled id differs from that label.
        is_replaced = (encoder_inputs['input_ids'] != labels) & (labels >= 0)
        replace_cnt = is_replaced.long().sum().item()
        total_cnt = (encoder_inputs['attention_mask'] == 1).long().sum().item()
        # Guard the denominator so a batch without attended tokens reports 0
        # instead of raising ZeroDivisionError.
        replace_ratio = torch.tensor(replace_cnt / max(total_cnt, 1), device=g_mlm_loss.device)

        encoder_out: MaskedLMOutput = self.encoder(
            **encoder_inputs,
            output_hidden_states=True,
            return_dict=True)

        # batch_size x 1 x hidden_dim
        cls_hidden = encoder_out.hidden_states[-1][:, :1]
        # batch_size x seq_length x embed_dim
        dec_inputs_embeds = self.encoder.bert.embeddings(decoder_inputs['input_ids'])
        # The decoder input is the encoder's first-position hidden state
        # followed by the embeddings of the remaining decoder tokens.
        hiddens = torch.cat([cls_hidden, dec_inputs_embeds[:, 1:]], dim=1)

        # The encoder-side attention mask is reused for the decoder layers.
        attention_mask = self.encoder.get_extended_attention_mask(
            encoder_inputs['attention_mask'],
            encoder_inputs['attention_mask'].shape,
            encoder_inputs['attention_mask'].device
        )

        for layer in self.decoder:
            layer_out = layer(hiddens, attention_mask)
            hiddens = layer_out[0]

        decoder_mlm_loss = self.mlm_loss(hiddens, labels)

        loss = decoder_mlm_loss + encoder_out.loss + g_mlm_loss * self.args.rlm_generator_mlm_weight

        return ReplaceLMOutput(loss=loss,
                               encoder_mlm_loss=encoder_out.loss.detach(),
                               decoder_mlm_loss=decoder_mlm_loss.detach(),
                               g_mlm_loss=g_mlm_loss.detach(),
                               replace_ratio=replace_ratio)

    def _replace_tokens(self, batch_dict: Dict[str, torch.Tensor],
                        no_grad: bool = False) -> Tuple:
        """Replace masked positions with tokens sampled from the generator.

        Args:
            batch_dict: generator inputs; must contain ``'input_ids'`` and
                ``'labels'`` and is passed to the generator as keyword
                arguments.
            no_grad: run the generator without gradient tracking even when it
                is not frozen.

        Returns:
            A tuple ``(sampled_input_ids, generator_loss)``: the input ids with
            every position whose label is ``>= 0`` replaced by a sampled token,
            and the loss reported by the generator.
        """
        with torch.no_grad() if self.args.rlm_freeze_generator or no_grad else nullcontext():
            outputs: MaskedLMOutput = self.generator(
                **batch_dict,
                return_dict=True)

        # Sampling is not differentiable, so it is always done without grad.
        with torch.no_grad():
            sampled_input_ids = Categorical(logits=outputs.logits).sample()
            is_mask = (batch_dict['labels'] >= 0).long()
            # Keep the original id where is_mask == 0, the sample where it is 1.
            sampled_input_ids = batch_dict['input_ids'] * (1 - is_mask) + sampled_input_ids * is_mask

        return sampled_input_ids.long(), outputs.loss

    def mlm_loss(self, hiddens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """Compute the mean cross-entropy MLM loss for the given hidden states.

        Args:
            hiddens: hidden states fed to the encoder's ``cls`` prediction head.
            labels: target token ids, flattened to one dimension before the
                loss is applied.

        Returns:
            The scalar loss from ``self.cross_entropy``.
        """
        pred_scores = self.encoder.cls(hiddens)
        mlm_loss = self.cross_entropy(
            pred_scores.view(-1, self.encoder.config.vocab_size),
            labels.view(-1))
        return mlm_loss

    @classmethod
    def from_pretrained(cls, all_args: Arguments,
                        model_name_or_path: str, *args, **kwargs):
        """Create a ``ReplaceLM`` from a pretrained masked-LM checkpoint.

        Args:
            all_args: run configuration forwarded to ``__init__``.
            model_name_or_path: model name or local directory passed to
                ``AutoModelForMaskedLM.from_pretrained``. If it is a directory
                containing ``decoder.pt``, the decoder weights are restored
                from that file.
            *args, **kwargs: forwarded to
                ``AutoModelForMaskedLM.from_pretrained``.

        Returns:
            The constructed model.
        """
        hf_model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, *args, **kwargs)
        model = cls(all_args, hf_model)
        decoder_save_path = os.path.join(model_name_or_path, 'decoder.pt')
        if os.path.exists(decoder_save_path):
            logger.info('loading extra weights from local files')
            # NOTE(security): torch.load unpickles the file; only load
            # checkpoints from trusted sources (or pass weights_only=True on
            # torch versions that support it).
            state_dict = torch.load(decoder_save_path, map_location="cpu")
            model.decoder.load_state_dict(state_dict)
        return model

    def save_pretrained(self, output_dir: str):
        """Save the encoder and the decoder weights into ``output_dir``.

        The encoder is saved with its own ``save_pretrained``; the decoder
        state dict is written to ``decoder.pt`` in the same directory. The
        generator is not saved.
        """
        self.encoder.save_pretrained(output_dir)
        torch.save(self.decoder.state_dict(), os.path.join(output_dir, 'decoder.pt'))