import copy
import os

import torch
import torch.nn as nn

from contextlib import nullcontext
from torch import Tensor
from torch.distributions import Categorical
from typing import Dict, Optional, Tuple
from dataclasses import dataclass
from transformers import AutoModelForMaskedLM, ElectraModel
from transformers.modeling_outputs import MaskedLMOutput, ModelOutput
from transformers.models.bert import BertForMaskedLM

from logger_config import logger
from config import Arguments
from utils import slice_batch_dict


@dataclass
class ReplaceLMOutput(ModelOutput):
    loss: Optional[Tensor] = None
    encoder_mlm_loss: Optional[Tensor] = None
    decoder_mlm_loss: Optional[Tensor] = None
    g_mlm_loss: Optional[Tensor] = None
    replace_ratio: Optional[Tensor] = None


class ReplaceLM(nn.Module):
    def __init__(self, args: Arguments,
                 bert: BertForMaskedLM):
        super(ReplaceLM, self).__init__()
        self.encoder = bert
        # the decoder reuses (a deep copy of) the last rlm_decoder_layers transformer layers of the encoder
        self.decoder = copy.deepcopy(self.encoder.bert.encoder.layer[-args.rlm_decoder_layers:])
        self.cross_entropy = nn.CrossEntropyLoss(reduction='mean')
        self.generator: ElectraModel = AutoModelForMaskedLM.from_pretrained(args.rlm_generator_model_name)
        if args.rlm_freeze_generator:
            self.generator.eval()
            self.generator.requires_grad_(False)
        self.args = args

        # imported lazily, most likely to avoid a circular import with the trainer module
        from trainers.rlm_trainer import ReplaceLMTrainer
        self.trainer: Optional[ReplaceLMTrainer] = None

    def forward(self, model_input: Dict[str, torch.Tensor]) -> ReplaceLMOutput:
        enc_prefix, dec_prefix = 'enc_', 'dec_'
        encoder_inputs = slice_batch_dict(model_input, enc_prefix)
        decoder_inputs = slice_batch_dict(model_input, dec_prefix)
        labels = model_input['labels']

        enc_sampled_input_ids, g_mlm_loss = self._replace_tokens(encoder_inputs)
        if self.args.rlm_freeze_generator:
            g_mlm_loss = torch.tensor(0, dtype=torch.float, device=g_mlm_loss.device)
        dec_sampled_input_ids, _ = self._replace_tokens(decoder_inputs, no_grad=True)

        encoder_inputs['input_ids'] = enc_sampled_input_ids
        decoder_inputs['input_ids'] = dec_sampled_input_ids
        # use the un-masked version of labels
        encoder_inputs['labels'] = labels
        decoder_inputs['labels'] = labels

        is_replaced = (encoder_inputs['input_ids'] != labels) & (labels >= 0)
        replace_cnt = is_replaced.long().sum().item()
        total_cnt = (encoder_inputs['attention_mask'] == 1).long().sum().item()
        replace_ratio = torch.tensor(replace_cnt / total_cnt, device=g_mlm_loss.device)

        encoder_out: MaskedLMOutput = self.encoder(
            **encoder_inputs,
            output_hidden_states=True,
            return_dict=True)

        # batch_size x 1 x hidden_dim
        cls_hidden = encoder_out.hidden_states[-1][:, :1]
        # batch_size x seq_length x embed_dim
        dec_inputs_embeds = self.encoder.bert.embeddings(decoder_inputs['input_ids'])
        # prepend the encoder's [CLS] hidden state to the decoder's input embeddings
        hiddens = torch.cat([cls_hidden, dec_inputs_embeds[:, 1:]], dim=1)

        attention_mask = self.encoder.get_extended_attention_mask(
            encoder_inputs['attention_mask'],
            encoder_inputs['attention_mask'].shape,
            encoder_inputs['attention_mask'].device
        )
        for layer in self.decoder:
            layer_out = layer(hiddens, attention_mask)
            hiddens = layer_out[0]

        decoder_mlm_loss = self.mlm_loss(hiddens, labels)
        loss = decoder_mlm_loss + encoder_out.loss + g_mlm_loss * self.args.rlm_generator_mlm_weight

        return ReplaceLMOutput(loss=loss,
                               encoder_mlm_loss=encoder_out.loss.detach(),
                               decoder_mlm_loss=decoder_mlm_loss.detach(),
                               g_mlm_loss=g_mlm_loss.detach(),
                               replace_ratio=replace_ratio)
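
    # _replace_tokens implements ELECTRA-style input corruption: the generator MLM
    # (kept frozen when rlm_freeze_generator is set) scores the masked inputs, a token
    # id is sampled from its output distribution at every position, and the samples are
    # kept only at the [MASK] positions, so the encoder and decoder see plausible
    # replacement tokens instead of [MASK] symbols.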
    def _replace_tokens(self, batch_dict: Dict[str, torch.Tensor],
                        no_grad: bool = False) -> Tuple:
        with torch.no_grad() if self.args.rlm_freeze_generator or no_grad else nullcontext():
            outputs: MaskedLMOutput = self.generator(
                **batch_dict,
                return_dict=True)

        with torch.no_grad():
            sampled_input_ids = Categorical(logits=outputs.logits).sample()
            is_mask = (batch_dict['labels'] >= 0).long()
            sampled_input_ids = batch_dict['input_ids'] * (1 - is_mask) + sampled_input_ids * is_mask

        return sampled_input_ids.long(), outputs.loss
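
    # Illustrative example of the blend above (token ids are made up):
    #   input_ids         = [101,  103,  2023, 103,  102]   # 103 == [MASK]
    #   labels            = [-100, 2003, -100, 2307, -100]  # -100 == not masked
    #   is_mask           = [0,    1,    0,    1,    0]
    #   sampled_input_ids = [845,  2001, 7592, 2204, 3899]
    #   blended result    = [101,  2001, 2023, 2204, 102]   # generator samples kept only at masked slots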

    def mlm_loss(self, hiddens: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        pred_scores = self.encoder.cls(hiddens)
        mlm_loss = self.cross_entropy(
            pred_scores.view(-1, self.encoder.config.vocab_size),
            labels.view(-1))
        return mlm_loss

    @classmethod
    def from_pretrained(cls, all_args: Arguments,
                        model_name_or_path: str, *args, **kwargs):
        hf_model = AutoModelForMaskedLM.from_pretrained(model_name_or_path, *args, **kwargs)
        model = cls(all_args, hf_model)

        decoder_save_path = os.path.join(model_name_or_path, 'decoder.pt')
        if os.path.exists(decoder_save_path):
            logger.info('loading extra weights from local files')
            state_dict = torch.load(decoder_save_path, map_location="cpu")
            model.decoder.load_state_dict(state_dict)
        return model

    def save_pretrained(self, output_dir: str):
        self.encoder.save_pretrained(output_dir)
        torch.save(self.decoder.state_dict(), os.path.join(output_dir, 'decoder.pt'))
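

# Minimal usage sketch, assuming `Arguments` accepts the rlm_* fields referenced above
# as keyword arguments; the real config class likely requires additional fields, and the
# generator checkpoint name below is only a plausible choice, not the project's default.
if __name__ == '__main__':
    args = Arguments(
        rlm_decoder_layers=2,                                       # assumed value
        rlm_generator_model_name='google/electra-small-generator',  # assumed checkpoint
        rlm_freeze_generator=True,                                  # assumed value
        rlm_generator_mlm_weight=1.0,                               # assumed value
    )
    # load a BERT-style MLM as the encoder and wrap it; decoder.pt is only loaded
    # when model_name_or_path points to a local directory that contains it
    model = ReplaceLM.from_pretrained(args, 'bert-base-uncased')
    # forward() expects a batch dict with 'enc_'-prefixed encoder inputs,
    # 'dec_'-prefixed decoder inputs and the un-masked 'labels' tensor
    logger.info('decoder has %d layers', len(model.decoder))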