# coding=utf-8
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Model classes for MetricX, modified from the T5 versions in HF.""" | |
import copy | |
import dataclasses | |
from typing import Optional, Tuple, Union | |
import warnings | |
import torch | |
from torch import nn | |
import transformers | |
import transformers.modeling_outputs | |
BaseModelOutput = transformers.modeling_outputs.BaseModelOutput | |
ModelOutput = transformers.modeling_outputs.ModelOutput | |
MT5Config = transformers.models.mt5.modeling_mt5.MT5Config | |
MT5PreTrainedModel = transformers.models.mt5.modeling_mt5.MT5PreTrainedModel | |
MT5Stack = transformers.models.mt5.modeling_mt5.MT5Stack | |
# A single leading underscore avoids Python's class-private name mangling when
# this constant is referenced inside MT5ForRegression.forward below.
_HEAD_MASK_WARNING_MSG = (
    transformers.models.mt5.modeling_mt5.__HEAD_MASK_WARNING_MSG  # pylint: disable=protected-access
)


@dataclasses.dataclass
class MT5ForRegressionOutput(ModelOutput):
  """Regression output: the MSE loss (if labels were given) and the scores."""

  loss: Optional[torch.FloatTensor] = None
  predictions: torch.FloatTensor = None


class MT5ForRegression(MT5PreTrainedModel):
  """MT5 model for regression."""

  def __init__(self, config: MT5Config):
    super().__init__(config)
    self.model_dim = config.d_model

    self.shared = nn.Embedding(config.vocab_size, config.d_model)

    encoder_config = copy.deepcopy(config)
    encoder_config.is_decoder = False
    encoder_config.use_cache = False
    encoder_config.is_encoder_decoder = False
    self.encoder = MT5Stack(encoder_config, self.shared)

    decoder_config = copy.deepcopy(config)
    decoder_config.is_decoder = True
    decoder_config.is_encoder_decoder = False
    decoder_config.num_layers = config.num_decoder_layers
    self.decoder = MT5Stack(decoder_config, self.shared)

    self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    # Initialize weights and apply final processing.
    self.post_init()

    # Model parallel
    self.model_parallel = False
    self.device_map = None

  def forward(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      attention_mask: Optional[torch.FloatTensor] = None,
      decoder_attention_mask: Optional[torch.BoolTensor] = None,
      head_mask: Optional[torch.FloatTensor] = None,
      decoder_head_mask: Optional[torch.FloatTensor] = None,
      cross_attn_head_mask: Optional[torch.Tensor] = None,
      encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
      past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
      inputs_embeds: Optional[torch.FloatTensor] = None,
      decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
      labels: Optional[torch.FloatTensor] = None,
      use_cache: Optional[bool] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      return_dict: Optional[bool] = None,
  ) -> Union[Tuple[torch.FloatTensor], MT5ForRegressionOutput]:
    """Runs the encoder-decoder once and returns a scalar score per example.

    Unlike generation, the decoder is fed a single dummy step and the score is
    read from the logits of that one position. If `labels` are provided, an
    MSE loss against the predicted scores is also returned.
    """
    use_cache = use_cache if use_cache is not None else self.config.use_cache
    return_dict = (
        return_dict if return_dict is not None else self.config.use_return_dict
    )
    # FutureWarning: head_mask was separated into two input args, head_mask and
    # decoder_head_mask.
    if head_mask is not None and decoder_head_mask is None:
      if self.config.num_layers == self.config.num_decoder_layers:
        warnings.warn(_HEAD_MASK_WARNING_MSG, FutureWarning)
        decoder_head_mask = head_mask
    # Encode if needed (training, first prediction pass).
    if encoder_outputs is None:
      # Convert encoder inputs into embeddings if needed.
      encoder_outputs = self.encoder(
          input_ids=input_ids,
          attention_mask=attention_mask,
          inputs_embeds=inputs_embeds,
          head_mask=head_mask,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=return_dict,
      )
    elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
      encoder_outputs = BaseModelOutput(
          last_hidden_state=encoder_outputs[0],
          hidden_states=encoder_outputs[1]
          if len(encoder_outputs) > 1
          else None,
          attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
      )

    hidden_states = encoder_outputs[0]

    if self.model_parallel:
      torch.cuda.set_device(self.decoder.first_device)
    # Create one step of dummy input for the decoder so that it produces
    # exactly one output position per example. Keep it on the same device as
    # the encoder output.
    batch_size = hidden_states.size(0)
    decoder_input_ids = torch.LongTensor([0]).repeat(batch_size).reshape(-1, 1)
    decoder_input_ids = decoder_input_ids.to(hidden_states.device)
    # Set device for model parallelism.
    if self.model_parallel:
      torch.cuda.set_device(self.decoder.first_device)
      hidden_states = hidden_states.to(self.decoder.first_device)
      if decoder_input_ids is not None:
        decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
      if attention_mask is not None:
        attention_mask = attention_mask.to(self.decoder.first_device)
      if decoder_attention_mask is not None:
        decoder_attention_mask = decoder_attention_mask.to(
            self.decoder.first_device
        )
    # Decode.
    decoder_outputs = self.decoder(
        input_ids=decoder_input_ids,
        attention_mask=decoder_attention_mask,
        inputs_embeds=decoder_inputs_embeds,
        past_key_values=past_key_values,
        encoder_hidden_states=hidden_states,
        encoder_attention_mask=attention_mask,
        head_mask=decoder_head_mask,
        cross_attn_head_mask=cross_attn_head_mask,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
    )
    sequence_output = decoder_outputs[0]

    # Set device for model parallelism.
    if self.model_parallel:
      torch.cuda.set_device(self.encoder.first_device)
      self.lm_head = self.lm_head.to(self.encoder.first_device)
      sequence_output = sequence_output.to(self.lm_head.weight.device)

    if self.config.tie_word_embeddings:
      # Rescale output before projecting on vocab.
      # See
      # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
      sequence_output = sequence_output * (self.model_dim**-0.5)

    lm_logits = self.lm_head(sequence_output)

    # The score is read from the logit of <extra_id_10> (vocab id 250089) at
    # the single decoder position.
    predictions = lm_logits[:, 0, 250089]

    # Clip predictions to the score range [0, 25].
    predictions = torch.clamp(predictions, 0, 25)

    loss = None
    if labels is not None:
      loss_fct = nn.MSELoss()
      # Move labels to the correct device to enable model/pipeline parallelism.
      labels = labels.to(predictions.device)
      loss = loss_fct(predictions.view(-1), labels.view(-1))

    return MT5ForRegressionOutput(
        loss=loss,
        predictions=predictions,
    )
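

# Minimal usage sketch (not part of the original model definition): it shows
# how the regression head above could be driven end to end. The checkpoint and
# tokenizer names ("google/metricx-23-xl-v2p0", "google/mt5-xl") and the
# "candidate: ... reference: ..." input format are assumptions based on common
# MetricX-23 usage; substitute the identifiers for the checkpoint you actually
# use.
if __name__ == "__main__":
  tokenizer = transformers.AutoTokenizer.from_pretrained("google/mt5-xl")
  model = MT5ForRegression.from_pretrained("google/metricx-23-xl-v2p0")
  model.eval()

  example = "candidate: Hallo Welt! reference: Hello world!"
  batch = tokenizer(
      example,
      max_length=1024,
      truncation=True,
      return_tensors="pt",
  )
  with torch.no_grad():
    output = model(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
    )
  # `predictions` holds one score per example, clipped to [0, 25]; for MetricX
  # lower scores indicate better translations.
  print(float(output.predictions[0]))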