import os
import torch.nn as nn
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss, BCEWithLogitsLoss
import numpy as np
from utils import MyTokenizer
from transformers import (
    RobertaConfig,
    RobertaModel,
    RobertaTokenizer,
    BartConfig,
    BartForConditionalGeneration,
    BartTokenizer,
    T5Config,
    T5ForConditionalGeneration,
    T5Tokenizer,
)
import logging

logger = logging.getLogger(__name__)


class ReviewerModel(T5ForConditionalGeneration):

    def __init__(self, config):
        super().__init__(config)
        # Binary classification head applied to the first encoder hidden state.
        self.cls_head = nn.Linear(self.config.d_model, 2, bias=True)
        self.init()

    def init(self):
        # Re-initialize the LM head and the classification head.
        nn.init.xavier_uniform_(self.lm_head.weight)
        factor = self.config.initializer_factor
        self.cls_head.weight.data.normal_(
            mean=0.0, std=factor * ((self.config.d_model) ** -0.5)
        )
        self.cls_head.bias.data.zero_()

    def forward(
        self, *argv, **kwargs
    ):
        r"""
        Doc from Huggingface transformers:
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to ``-100`` are ignored (masked), the loss is only computed for
            labels in ``[0, ..., config.vocab_size]``
        Returns:
        Examples::
            >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
            >>> tokenizer = T5Tokenizer.from_pretrained('t5-small')
            >>> model = T5ForConditionalGeneration.from_pretrained('t5-small')
            >>> # training
            >>> input_ids = tokenizer('The <extra_id_0> walks in <extra_id_1> park', return_tensors='pt').input_ids
            >>> labels = tokenizer('<extra_id_0> cute dog <extra_id_1> the <extra_id_2>', return_tensors='pt').input_ids
            >>> outputs = model(input_ids=input_ids, labels=labels)
            >>> loss = outputs.loss
            >>> logits = outputs.logits
            >>> # inference
            >>> input_ids = tokenizer("summarize: studies have shown that owning a dog is good for you", return_tensors="pt").input_ids  # Batch size 1
            >>> outputs = model.generate(input_ids)
            >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
            >>> # studies have shown that owning a dog is good for you.
        """
if "cls" in kwargs: | |
assert ( | |
"input_ids" in kwargs and \ | |
"labels" in kwargs and \ | |
"attention_mask" in kwargs | |
) | |
return self.cls( | |
input_ids=kwargs["input_ids"], | |
labels=kwargs["labels"], | |
attention_mask=kwargs["attention_mask"], | |
) | |
if "input_labels" in kwargs: | |
assert ( | |
"input_ids" in kwargs and \ | |
"input_labels" in kwargs and \ | |
"decoder_input_ids" in kwargs and \ | |
"attention_mask" in kwargs and \ | |
"decoder_attention_mask" in kwargs | |
), "Please give these arg keys." | |
input_ids = kwargs["input_ids"] | |
input_labels = kwargs["input_labels"] | |
decoder_input_ids = kwargs["decoder_input_ids"] | |
attention_mask = kwargs["attention_mask"] | |
decoder_attention_mask = kwargs["decoder_attention_mask"] | |
if "encoder_loss" not in kwargs: | |
encoder_loss = True | |
else: | |
encoder_loss = kwargs["encoder_loss"] | |
return self.review_forward(input_ids, input_labels, decoder_input_ids, attention_mask, decoder_attention_mask, encoder_loss) | |
return super().forward(*argv, **kwargs) | |

    def cls(
        self,
        input_ids,
        labels,
        attention_mask,
    ):
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False,
        )
        hidden_states = encoder_outputs[0]
        # Use the hidden state of the first token as the sequence representation.
        first_hidden = hidden_states[:, 0, :]
        first_hidden = nn.Dropout(0.3)(first_hidden)
        logits = self.cls_head(first_hidden)
        loss_fct = CrossEntropyLoss()
        if labels is not None:
            loss = loss_fct(logits, labels)
            return loss
        return logits

    def review_forward(
        self,
        input_ids,
        input_labels,
        decoder_input_ids,
        attention_mask,
        decoder_attention_mask,
        encoder_loss=True,
    ):
        encoder_outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False,
        )
        hidden_states = encoder_outputs[0]
        decoder_inputs = self._shift_right(decoder_input_ids)
        # Decode
        decoder_outputs = self.decoder(
            input_ids=decoder_inputs,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=hidden_states,
            encoder_attention_mask=attention_mask,
            output_attentions=False,
            return_dict=False,
        )
        sequence_output = decoder_outputs[0]
        if self.config.tie_word_embeddings:  # True by default
            sequence_output = sequence_output * (self.model_dim ** -0.5)
        if encoder_loss:
            # Per-token logits over the vocabulary, tied to the input embeddings.
            # print(self.encoder.get_input_embeddings().weight.shape)
            cls_logits = nn.functional.linear(hidden_states, self.encoder.get_input_embeddings().weight)
            # cls_logits = self.cls_head(hidden_states)
        lm_logits = self.lm_head(sequence_output)
        if decoder_input_ids is not None:
            lm_loss_fct = CrossEntropyLoss(ignore_index=0)  # Warning: PAD_ID should be 0
            loss = lm_loss_fct(lm_logits.view(-1, lm_logits.size(-1)), decoder_input_ids.view(-1))
            if encoder_loss and input_labels is not None:
                cls_loss_fct = CrossEntropyLoss(ignore_index=-100)
                loss += cls_loss_fct(cls_logits.view(-1, cls_logits.size(-1)), input_labels.view(-1))
            return loss
        return cls_logits, lm_logits
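

# Illustrative sketch only (not part of the original training pipeline): how the two
# custom forward paths above are typically driven. The dummy tensor shapes and label
# values are placeholder assumptions for demonstration; the real run scripts build
# inputs through the project tokenizer and data collators.
def _example_reviewer_forward(model):
    # Tiny fake batch: 2 sequences of length 8, assuming token ids 5..99 are valid.
    input_ids = torch.randint(5, 100, (2, 8))
    attention_mask = torch.ones_like(input_ids)

    # Classification path: binary labels scored by cls_head on the first token.
    cls_loss = model(
        cls=True,
        input_ids=input_ids,
        labels=torch.tensor([0, 1]),
        attention_mask=attention_mask,
    )

    # Review path: per-token encoder labels plus decoder targets.
    decoder_input_ids = torch.randint(5, 100, (2, 6))
    review_loss = model(
        input_labels=torch.randint(0, 2, (2, 8)),  # placeholder tag ids
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        attention_mask=attention_mask,
        decoder_attention_mask=torch.ones_like(decoder_input_ids),
    )
    return cls_loss, review_loss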


def get_model_size(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    model_size = sum([np.prod(p.size()) for p in model_parameters])
    return "{}M".format(round(model_size / 1e6))


def build_or_load_gen_model(args):
    config_class, model_class, tokenizer_class = T5Config, ReviewerModel, RobertaTokenizer
    config = config_class.from_pretrained(args.model_name_or_path)
    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path, config=config)
    # Cache the ids of the extra special tokens used by the review tasks.
    tokenizer.special_dict = {
        f"<e{i}>": tokenizer.get_vocab()[f"<e{i}>"] for i in range(99, -1, -1)
    }
    tokenizer.mask_id = tokenizer.get_vocab()["<mask>"]
    tokenizer.bos_id = tokenizer.get_vocab()["<s>"]
    tokenizer.pad_id = tokenizer.get_vocab()["<pad>"]
    tokenizer.eos_id = tokenizer.get_vocab()["</s>"]
    tokenizer.msg_id = tokenizer.get_vocab()["<msg>"]
    tokenizer.keep_id = tokenizer.get_vocab()["<keep>"]
    tokenizer.add_id = tokenizer.get_vocab()["<add>"]
    tokenizer.del_id = tokenizer.get_vocab()["<del>"]
    tokenizer.start_id = tokenizer.get_vocab()["<start>"]
    tokenizer.end_id = tokenizer.get_vocab()["<end>"]
    logger.info(
        "Finish loading model [%s] from %s",
        get_model_size(model),
        args.model_name_or_path,
    )
    if args.load_model_path is not None:
        model_path = os.path.join(args.load_model_path, "pytorch_model.bin")
        logger.info("Reload model from {}".format(model_path))
        try:
            model.load_state_dict(torch.load(model_path, map_location="cpu"))
        except RuntimeError:
            # The checkpoint may not contain cls_head weights; retry loading without
            # the head and keep the freshly initialized one.
            saved = model.cls_head
            model.cls_head = None
            model.load_state_dict(torch.load(model_path, map_location="cpu"))
            model.cls_head = saved
    model.to(args.local_rank)
    return config, model, tokenizer
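

# Minimal usage sketch, assuming an args namespace with the three attributes read
# above (model_name_or_path, load_model_path, local_rank). The checkpoint id and the
# device value are placeholders, not prescribed by this file; the chosen checkpoint's
# tokenizer must already contain the <msg>/<keep>/<add>/<del>/<start>/<end>/<e*> tokens
# referenced in build_or_load_gen_model.
def _example_build_model():
    from types import SimpleNamespace

    args = SimpleNamespace(
        model_name_or_path="microsoft/codereviewer",  # placeholder checkpoint id
        load_model_path=None,  # skip reloading a fine-tuned state dict
        local_rank="cpu",      # any device spec accepted by model.to()
    )
    config, model, tokenizer = build_or_load_gen_model(args)
    logger.info("Loaded %s with vocab size %d", get_model_size(model), config.vocab_size)
    return config, model, tokenizer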