# codereviewer/models.py
import os
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
import numpy as np
from utils import MyTokenizer
from transformers import (
RobertaConfig,
RobertaModel,
RobertaTokenizer,
BartConfig,
BartForConditionalGeneration,
BartTokenizer,
T5Config,
T5ForConditionalGeneration,
T5Tokenizer,
)
import logging
logger = logging.getLogger(__name__)
class ReviewerModel(T5ForConditionalGeneration):
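    """T5ForConditionalGeneration extended with a 2-way classification head
    (cls_head) applied to the first encoder hidden state."""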
def __init__(self, config):
super().__init__(config)
self.cls_head = nn.Linear(self.config.d_model, 2, bias=True)
self.init()
    def init(self):
        # Freshly initialize the generation head and the 2-way classification head.
        nn.init.xavier_uniform_(self.lm_head.weight)
factor = self.config.initializer_factor
        self.cls_head.weight.data.normal_(
            mean=0.0, std=factor * (self.config.d_model ** -0.5)
        )
self.cls_head.bias.data.zero_()
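    # forward() dispatches on its keyword arguments:
    #   - "cls" in kwargs          -> cls(): binary classification on the encoder output
    #   - "input_labels" in kwargs -> review_forward(): decoder LM loss (+ optional encoder tagging loss)
    #   - otherwise                -> plain T5ForConditionalGeneration.forward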
def forward(
self, *argv, **kwargs
):
r"""
Doc from Huggingface transformers:
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ...,
config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for
        labels in ``[0, ..., config.vocab_size - 1]``
Returns:
Examples::
>>> from transformers import T5Tokenizer, T5ForConditionalGeneration
>>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
>>> model = T5ForConditionalGeneration.from_pretrained('t5-small')
>>> # training
>>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
>>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
>>> outputs = model(input_ids=input_ids, labels=labels)
>>> loss = outputs.loss
>>> logits = outputs.logits
>>> # inference
>>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids # Batch size 1
>>> outputs = model.generate(input_ids)
>>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
>>> # studies have shown that owning a dog is good for you.
"""
if "cls" in kwargs:
            assert (
                "input_ids" in kwargs and
                "labels" in kwargs and
                "attention_mask" in kwargs
            ), "cls mode requires input_ids, labels and attention_mask."
return self.cls(
input_ids=kwargs["input_ids"],
labels=kwargs["labels"],
attention_mask=kwargs["attention_mask"],
)
if "input_labels" in kwargs:
            assert (
                "input_ids" in kwargs and
                "input_labels" in kwargs and
                "decoder_input_ids" in kwargs and
                "attention_mask" in kwargs and
                "decoder_attention_mask" in kwargs
            ), "Please give these arg keys."
input_ids = kwargs["input_ids"]
input_labels = kwargs["input_labels"]
decoder_input_ids = kwargs["decoder_input_ids"]
attention_mask = kwargs["attention_mask"]
decoder_attention_mask = kwargs["decoder_attention_mask"]
if "encoder_loss" not in kwargs:
encoder_loss = True
else:
encoder_loss = kwargs["encoder_loss"]
return self.review_forward(input_ids, input_labels, decoder_input_ids, attention_mask, decoder_attention_mask, encoder_loss)
return super().forward(*argv, **kwargs)
def cls(
self,
input_ids,
labels,
attention_mask,
):
        encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=False,
return_dict=False
)
hidden_states = encoder_outputs[0]
first_hidden = hidden_states[:, 0, :]
        # Apply dropout only in training mode (a freshly built nn.Dropout is always in train mode).
        first_hidden = F.dropout(first_hidden, p=0.3, training=self.training)
logits = self.cls_head(first_hidden)
loss_fct = CrossEntropyLoss()
        if labels is not None:
loss = loss_fct(logits, labels)
return loss
return logits
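    # review_forward() encodes the source, decodes with teacher forcing
    # (decoder inputs shifted right) and returns the LM cross-entropy over
    # decoder_input_ids (pad id 0 ignored); with encoder_loss it adds a
    # token-tagging loss on the encoder states, whose logits are tied to the
    # input embeddings and whose -100 labels are masked out.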
def review_forward(
self,
input_ids,
input_labels,
decoder_input_ids,
attention_mask,
decoder_attention_mask,
encoder_loss=True
):
        encoder_outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
output_attentions=False,
return_dict=False
)
hidden_states = encoder_outputs[0]
decoder_inputs = self._shift_right(decoder_input_ids)
# Decode
decoder_outputs = self.decoder(
input_ids=decoder_inputs,
attention_mask=decoder_attention_mask,
encoder_hidden_states=hidden_states,
encoder_attention_mask=attention_mask,
output_attentions=False,
return_dict=False
)
sequence_output = decoder_outputs[0]
        if self.config.tie_word_embeddings:  # True by default for T5
sequence_output = sequence_output * (self.model_dim ** -0.5)
if encoder_loss:
# print(self.encoder.get_input_embeddings().weight.shape)
            # Encoder-side vocabulary logits, tied to the shared input embedding matrix.
            cls_logits = F.linear(hidden_states, self.encoder.get_input_embeddings().weight)
# cls_logits = self.cls_head(hidden_states)
lm_logits = self.lm_head(sequence_output)
if decoder_input_ids is not None:
lm_loss_fct = CrossEntropyLoss(ignore_index=0) # Warning: PAD_ID should be 0
loss = lm_loss_fct(lm_logits.view(-1, lm_logits.size(-1)), decoder_input_ids.view(-1))
if encoder_loss and input_labels is not None:
cls_loss_fct = CrossEntropyLoss(ignore_index=-100)
loss += cls_loss_fct(cls_logits.view(-1, cls_logits.size(-1)), input_labels.view(-1))
return loss
return cls_logits, lm_logits
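# The sketch below is illustrative only (it is not part of the original
# CodeReviewer code) and assumes `model` and `tokenizer` come from
# build_or_load_gen_model(); the token ids 500-503 are arbitrary placeholders.
def _example_forward_modes(model, tokenizer):
    input_ids = torch.tensor([[tokenizer.bos_id, 500, 501, tokenizer.eos_id]])
    attention_mask = torch.ones_like(input_ids)
    # Classification path: the "cls" kwarg routes forward() to cls(), which
    # returns a scalar loss for the binary labels.
    cls_loss = model(
        cls=True,
        input_ids=input_ids,
        labels=torch.tensor([1]),
        attention_mask=attention_mask,
    )
    # Review path: the "input_labels" kwarg routes forward() to review_forward(),
    # which sums the decoder LM loss and the encoder tagging loss (-100 ignored).
    decoder_input_ids = torch.tensor([[tokenizer.msg_id, 502, 503, tokenizer.eos_id]])
    review_loss = model(
        input_ids=input_ids,
        input_labels=torch.tensor([[-100, tokenizer.keep_id, tokenizer.add_id, -100]]),
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=torch.ones_like(decoder_input_ids),
        encoder_loss=True,
    )
    return cls_loss, review_loss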
def get_model_size(model):
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
model_size = sum([np.prod(p.size()) for p in model_parameters])
return "{}M".format(round(model_size / 1e6))
def build_or_load_gen_model(args):
config_class, model_class, tokenizer_class = T5Config, ReviewerModel, RobertaTokenizer
config = config_class.from_pretrained(args.model_name_or_path)
tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
model = model_class.from_pretrained(args.model_name_or_path, config=config)
    # Cache the vocabulary once and expose convenient ids for the special tokens.
    vocab = tokenizer.get_vocab()
    tokenizer.special_dict = {
        f"<e{i}>": vocab[f"<e{i}>"] for i in range(99, -1, -1)
    }
    tokenizer.mask_id = vocab["<mask>"]
    tokenizer.bos_id = vocab["<s>"]
    tokenizer.pad_id = vocab["<pad>"]
    tokenizer.eos_id = vocab["</s>"]
    tokenizer.msg_id = vocab["<msg>"]
    tokenizer.keep_id = vocab["<keep>"]
    tokenizer.add_id = vocab["<add>"]
    tokenizer.del_id = vocab["<del>"]
    tokenizer.start_id = vocab["<start>"]
    tokenizer.end_id = vocab["<end>"]
logger.info(
"Finish loading model [%s] from %s",
get_model_size(model),
args.model_name_or_path,
)
if args.load_model_path is not None:
model_path = os.path.join(args.load_model_path, "pytorch_model.bin")
logger.info("Reload model from {}".format(model_path))
        try:
            model.load_state_dict(torch.load(model_path, map_location="cpu"))
        except RuntimeError:
            # The checkpoint may lack cls_head weights; detach the head so the
            # remaining parameters load strictly, then put it back.
            saved = model.cls_head
            model.cls_head = None
            model.load_state_dict(torch.load(model_path, map_location="cpu"))
            model.cls_head = saved
model.to(args.local_rank)
return config, model, tokenizer
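if __name__ == "__main__":
    # Minimal usage sketch (not part of the original file); it assumes the
    # public "microsoft/codereviewer" checkpoint, whose vocabulary contains
    # the special tokens (<msg>, <keep>, <add>, <del>, <e0>..<e99>, ...)
    # looked up in build_or_load_gen_model().
    from types import SimpleNamespace
    demo_args = SimpleNamespace(
        model_name_or_path="microsoft/codereviewer",
        load_model_path=None,  # or a directory containing pytorch_model.bin
        local_rank="cpu",      # any device spec accepted by model.to()
    )
    config, model, tokenizer = build_or_load_gen_model(demo_args)
    logger.info("Loaded %s model; <msg> id = %d", get_model_size(model), tokenizer.msg_id)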