import copy
import os

import torch
import torch.nn as nn

from contextlib import nullcontext
from torch import Tensor
from torch.distributions import Categorical
from typing import Dict, Optional, Tuple
from dataclasses import dataclass
from transformers import AutoModelForMaskedLM, ElectraForMaskedLM
from transformers.modeling_outputs import MaskedLMOutput, ModelOutput
from transformers.models.bert import BertForMaskedLM

from logger_config import logger
from config import Arguments
from utils import slice_batch_dict


@dataclass
class ReplaceLMOutput(ModelOutput):
    loss: Optional[Tensor] = None
    encoder_mlm_loss: Optional[Tensor] = None
    decoder_mlm_loss: Optional[Tensor] = None
    g_mlm_loss: Optional[Tensor] = None
    replace_ratio: Optional[Tensor] = None


class ReplaceLM(nn.Module):
    def __init__(self, args: Arguments,
                 bert: BertForMaskedLM):
        super(ReplaceLM, self).__init__()
        self.encoder = bert
        # The decoder is a deep copy of the encoder's last rlm_decoder_layers transformer layers.
        self.decoder = copy.deepcopy(self.encoder.bert.encoder.layer[-args.rlm_decoder_layers:])
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')

        # An ELECTRA-style generator proposes replacement tokens for masked positions.
        self.generator: ElectraForMaskedLM = AutoModelForMaskedLM.from_pretrained(args.rlm_generator_model_name)
        if args.rlm_freeze_generator:
            self.generator.eval()
            self.generator.requires_grad_(False)

        self.args = args

        # Deferred import to avoid a circular dependency between model and trainer modules.
        from trainers.rlm_trainer import ReplaceLMTrainer
        self.trainer: Optional[ReplaceLMTrainer] = None

    def forward(self, model_input: Dict[str, torch.Tensor]) -> ReplaceLMOutput:
        enc_prefix, dec_prefix = 'enc_', 'dec_'
        encoder_inputs = slice_batch_dict(model_input, enc_prefix)
        decoder_inputs = slice_batch_dict(model_input, dec_prefix)
        labels = model_input['labels']

        # Replace masked tokens with samples drawn from the generator.
        enc_sampled_input_ids, g_mlm_loss = self._replace_tokens(encoder_inputs)
        if self.args.rlm_freeze_generator:
            g_mlm_loss = torch.tensor(0, dtype=torch.float, device=g_mlm_loss.device)
        dec_sampled_input_ids, _ = self._replace_tokens(decoder_inputs, no_grad=True)

        encoder_inputs['input_ids'] = enc_sampled_input_ids
        decoder_inputs['input_ids'] = dec_sampled_input_ids
        # use the un-masked version of labels
        encoder_inputs['labels'] = labels
        decoder_inputs['labels'] = labels

        # Track how many non-padding tokens the generator actually changed.
        is_replaced = (encoder_inputs['input_ids'] != labels) & (labels >= 0)
        replace_cnt = is_replaced.long().sum().item()
        total_cnt = (encoder_inputs['attention_mask'] == 1).long().sum().item()
        replace_ratio = torch.tensor(replace_cnt / total_cnt, device=g_mlm_loss.device)

        encoder_out: MaskedLMOutput = self.encoder(
            **encoder_inputs,
            output_hidden_states=True,
            return_dict=True)

        # batch_size x 1 x hidden_dim
        cls_hidden = encoder_out.hidden_states[-1][:, :1]
        # batch_size x seq_length x embed_dim
        dec_inputs_embeds = self.encoder.bert.embeddings(decoder_inputs['input_ids'])
        # Condition the decoder on the encoder's [CLS] vector in place of its own first embedding.
        hiddens = torch.cat([cls_hidden, dec_inputs_embeds[:, 1:]], dim=1)

        attention_mask = self.encoder.get_extended_attention_mask(
            encoder_inputs['attention_mask'],
            encoder_inputs['attention_mask'].shape,
            encoder_inputs['attention_mask'].device
        )
        for layer in self.decoder:
            layer_out = layer(hiddens, attention_mask)
            hiddens = layer_out[0]

        decoder_mlm_loss = self.mlm_loss(hiddens, labels)
        loss = decoder_mlm_loss + encoder_out.loss + g_mlm_loss * self.args.rlm_generator_mlm_weight

        return ReplaceLMOutput(loss=loss,
                               encoder_mlm_loss=encoder_out.loss.detach(),
                               decoder_mlm_loss=decoder_mlm_loss.detach(),
                               g_mlm_loss=g_mlm_loss.detach(),
                               replace_ratio=replace_ratio)

    def _replace_tokens(self, batch_dict: Dict[str, torch.Tensor],
                        no_grad: bool = False) -> Tuple:
        # Only backpropagate through the generator when it is jointly trained.
        with torch.no_grad() if self.args.rlm_freeze_generator or no_grad else nullcontext():
            outputs: MaskedLMOutput = self.generator(
                **batch_dict,
                return_dict=True)

        with torch.no_grad():
            sampled_input_ids = Categorical(logits=outputs.logits).sample()
            is_mask = (batch_dict['labels'] >= 0).long()
            # Keep original tokens at unmasked positions; take the sampled token at masked ones.
            sampled_input_ids = batch_dict['input_ids'] * (1 - is_mask) + sampled_input_ids * is_mask

        return sampled_input_ids.long(), outputs.loss
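
    # Worked example (illustrative, with hypothetical token ids, not taken from
    # this repo): if input_ids = [CLS, the, MASK, sat] and
    # labels = [-100, -100, id("cat"), -100], then is_mask = [0, 0, 1, 0], so
    # only the [MASK] slot takes the generator's sample; every other position
    # keeps its original token.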

    def mlm_loss(self, hiddens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Score decoder hidden states with the encoder's MLM prediction head.
        pred_scores = self.encoder.cls(hiddens)
        mlm_loss = self.cross_entropy(
            pred_scores.view(-1, self.encoder.config.vocab_size),
            labels.view(-1))
        return mlm_loss

    @classmethod
    def from_pretrained(cls, all_args: Arguments,
                        model_name_or_path: str, *args, **kwargs):
        hf_model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, *args, **kwargs)
        model = cls(all_args, hf_model)

        decoder_save_path = os.path.join(model_name_or_path, 'decoder.pt')
        if os.path.exists(decoder_save_path):
            logger.info('loading extra weights from local files')
            state_dict = torch.load(decoder_save_path, map_location="cpu")
            model.decoder.load_state_dict(state_dict)
        return model

    def save_pretrained(self, output_dir: str):
        self.encoder.save_pretrained(output_dir)
        # The decoder layers are not part of the HF checkpoint, so persist them separately.
        torch.save(self.decoder.state_dict(), os.path.join(output_dir, 'decoder.pt'))
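

# Usage sketch (illustrative only; not part of the original module). It assumes
# the repo's Arguments instance exposes the rlm_* fields referenced above, and
# that the data collator emits 'enc_'/'dec_'-prefixed MLM views plus un-masked
# 'labels' (original ids at non-padding positions, -100 at padding); the exact
# batch format is defined by the repo's data pipeline, not here.
#
#   model = ReplaceLM.from_pretrained(args, 'bert-base-uncased')
#   batch = {
#       'enc_input_ids': ..., 'enc_attention_mask': ..., 'enc_labels': ...,
#       'dec_input_ids': ..., 'dec_attention_mask': ..., 'dec_labels': ...,
#       'labels': ...,
#   }
#   out: ReplaceLMOutput = model(batch)
#   out.loss.backward()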