import copy
import os
import torch
import torch.nn as nn
from contextlib import nullcontext
from torch import Tensor
from torch.distributions import Categorical
from typing import Dict, Optional, Tuple
from dataclasses import dataclass
from transformers import AutoModelForMaskedLM, ElectraModel
from transformers.modeling_outputs import MaskedLMOutput, ModelOutput
from transformers.models.bert import BertForMaskedLM
from logger_config import logger
from config import Arguments
from utils import slice_batch_dict

@dataclass
class ReplaceLMOutput(ModelOutput):
    loss: Optional[Tensor] = None
    encoder_mlm_loss: Optional[Tensor] = None
    decoder_mlm_loss: Optional[Tensor] = None
    g_mlm_loss: Optional[Tensor] = None
    replace_ratio: Optional[Tensor] = None

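# ReplaceLM pairs a small ELECTRA-style generator with a BERT encoder and a
# shallow decoder: the generator samples replacement tokens at masked
# positions, and both encoder and decoder learn to recover the originals.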
class ReplaceLM(nn.Module):
    def __init__(self, args: Arguments,
                 bert: BertForMaskedLM):
        super(ReplaceLM, self).__init__()
        self.encoder = bert
        # the shallow decoder is a deep copy of the encoder's last few layers
        self.decoder = copy.deepcopy(self.encoder.bert.encoder.layer[-args.rlm_decoder_layers:])
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        self.generator: ElectraModel = AutoModelForMaskedLM.from_pretrained(args.rlm_generator_model_name)
        if args.rlm_freeze_generator:
            self.generator.eval()
            self.generator.requires_grad_(False)

        self.args = args
        # imported inside __init__, presumably to avoid a circular import
        # with the trainer module
        from trainers.rlm_trainer import ReplaceLMTrainer
        self.trainer: Optional[ReplaceLMTrainer] = None
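
    # the batch holds two independently masked views of each example, keyed
    # with 'enc_' / 'dec_' prefixes; 'labels' carries the original token ids
    # at masked positions and a negative fill value elsewhere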
    def forward(self, model_input: Dict[str, torch.Tensor]) -> ReplaceLMOutput:
        enc_prefix, dec_prefix = 'enc_', 'dec_'
        encoder_inputs = slice_batch_dict(model_input, enc_prefix)
        decoder_inputs = slice_batch_dict(model_input, dec_prefix)
        labels = model_input['labels']

        enc_sampled_input_ids, g_mlm_loss = self._replace_tokens(encoder_inputs)
        if self.args.rlm_freeze_generator:
            g_mlm_loss = torch.tensor(0, dtype=torch.float, device=g_mlm_loss.device)
        dec_sampled_input_ids, _ = self._replace_tokens(decoder_inputs, no_grad=True)

        encoder_inputs['input_ids'] = enc_sampled_input_ids
        decoder_inputs['input_ids'] = dec_sampled_input_ids
        # use the un-masked version of labels
        encoder_inputs['labels'] = labels
        decoder_inputs['labels'] = labels

        # fraction of attended tokens that the generator actually replaced
        is_replaced = (encoder_inputs['input_ids'] != labels) & (labels >= 0)
        replace_cnt = is_replaced.long().sum().item()
        total_cnt = (encoder_inputs['attention_mask'] == 1).long().sum().item()
        replace_ratio = torch.tensor(replace_cnt / total_cnt, device=g_mlm_loss.device)

        encoder_out: MaskedLMOutput = self.encoder(
            **encoder_inputs,
            output_hidden_states=True,
            return_dict=True)

        # batch_size x 1 x hidden_dim
        cls_hidden = encoder_out.hidden_states[-1][:, :1]
        # batch_size x seq_length x embed_dim
        dec_inputs_embeds = self.encoder.bert.embeddings(decoder_inputs['input_ids'])
        # the decoder only sees the encoder's [CLS] state plus raw token
        # embeddings, making the [CLS] vector an information bottleneck
        hiddens = torch.cat([cls_hidden, dec_inputs_embeds[:, 1:]], dim=1)

        attention_mask = self.encoder.get_extended_attention_mask(
            encoder_inputs['attention_mask'],
            encoder_inputs['attention_mask'].shape,
            encoder_inputs['attention_mask'].device
        )
        for layer in self.decoder:
            layer_out = layer(hiddens, attention_mask)
            hiddens = layer_out[0]

        decoder_mlm_loss = self.mlm_loss(hiddens, labels)
        loss = decoder_mlm_loss + encoder_out.loss + g_mlm_loss * self.args.rlm_generator_mlm_weight

        return ReplaceLMOutput(loss=loss,
                               encoder_mlm_loss=encoder_out.loss.detach(),
                               decoder_mlm_loss=decoder_mlm_loss.detach(),
                               g_mlm_loss=g_mlm_loss.detach(),
                               replace_ratio=replace_ratio)
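
    # ELECTRA-style corruption: sampling is discrete, so no gradient flows
    # back through the replaced ids; the generator trains on its own MLM loss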
    def _replace_tokens(self, batch_dict: Dict[str, torch.Tensor],
                        no_grad: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        with torch.no_grad() if self.args.rlm_freeze_generator or no_grad else nullcontext():
            outputs: MaskedLMOutput = self.generator(
                **batch_dict,
                return_dict=True)

        with torch.no_grad():
            sampled_input_ids = Categorical(logits=outputs.logits).sample()
            # keep the original ids everywhere except at masked positions
            is_mask = (batch_dict['labels'] >= 0).long()
            sampled_input_ids = batch_dict['input_ids'] * (1 - is_mask) + sampled_input_ids * is_mask

        return sampled_input_ids.long(), outputs.loss
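
    # the decoder reuses the encoder's MLM prediction head for its loss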
    def mlm_loss(self, hiddens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        pred_scores = self.encoder.cls(hiddens)
        mlm_loss = self.cross_entropy(
            pred_scores.view(-1, self.encoder.config.vocab_size),
            labels.view(-1))
        return mlm_loss
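
    # checkpoints round-trip through save_pretrained below: the encoder uses
    # the standard Hugging Face format, the decoder a separate 'decoder.pt'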
    @classmethod
    def from_pretrained(cls, all_args: Arguments,
                        model_name_or_path: str, *args, **kwargs):
        hf_model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, *args, **kwargs)
        model = cls(all_args, hf_model)
        decoder_save_path = os.path.join(model_name_or_path, 'decoder.pt')
        if os.path.exists(decoder_save_path):
            logger.info('loading extra weights from local files')
            state_dict = torch.load(decoder_save_path, map_location="cpu")
            model.decoder.load_state_dict(state_dict)
        return model
    def save_pretrained(self, output_dir: str):
        self.encoder.save_pretrained(output_dir)
        torch.save(self.decoder.state_dict(), os.path.join(output_dir, 'decoder.pt'))
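
# Minimal usage sketch (hypothetical values; any `Arguments` fields beyond the
# ones referenced in this file are assumptions):
#
#   args = Arguments(rlm_decoder_layers=2,
#                    rlm_generator_model_name='google/electra-small-generator',
#                    rlm_freeze_generator=True,
#                    rlm_generator_mlm_weight=1.0)
#   model = ReplaceLM.from_pretrained(args, 'bert-base-uncased')
#   out = model(batch)  # batch: 'enc_*'/'dec_*' tensors plus 'labels'
#   out.loss.backward()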