import random
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.roberta.modeling_roberta import (
    RobertaLMHead,
    RobertaModel,
    RobertaPreTrainedModel,
)
from transformers.utils import logging

from sdlm.utils import convert_to_simplex, mix_values_based_on_self_condition

logger = logging.get_logger(__name__)
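
# NOTE: `convert_to_simplex` and `mix_values_based_on_self_condition` are defined in `sdlm.utils`,
# not in this file. As used below, `convert_to_simplex(token_ids, simplex_value, vocab_size)` is
# assumed to return a (batch, seq_len, vocab_size) tensor of near-one-hot logits (roughly
# +simplex_value at each token's index and -simplex_value elsewhere), and
# `mix_values_based_on_self_condition(mode, current, previous)` is assumed to combine the current
# and previous predictions according to the configured self-conditioning mode (e.g. addition, max,
# or mean for the "logits_*" modes). These descriptions are inferred from how the helpers are
# called below; see `sdlm.utils` for the actual definitions.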


class RobertaForDiffusionLM(RobertaPreTrainedModel):
    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"lm_head.decoder.weight",
        r"lm_head.decoder.bias",
    ]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)
        if config.is_decoder:
            logger.warning(
                "If you want to use `RobertaForDiffusionLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )
        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)
        # # The LM head weights require special treatment only when they are tied with the word embeddings
        # self.update_keys_to_ignore(config, ["lm_head.decoder.weight"])
        # self.vocab_to_hidden_dim_embed = nn.Linear(
        #     config.vocab_size, config.hidden_size, bias=False
        # )
        self.timestep_embed = nn.Linear(1, config.hidden_size, bias=True)

        if self.config.self_condition is not None and self.config.deepmind_conditional:
            # This is self-conditioning with conditional generation as done in the DeepMind paper.
            # See Figure 3 in https://arxiv.org/pdf/2211.15089.pdf.
            # Here we concatenate the masked word embeddings, the noisy embeddings, the mask, and the
            # self-conditioning inputs, and project them back to hidden_size.
            self.project_to_hidden_size = nn.Linear(
                config.hidden_size * 4, config.hidden_size, bias=False
            )
        elif (
            self.config.self_condition is not None
            and self.config.self_condition
            not in [
                "logits_addition",
                "logits_with_projection_addition",
                "logits_max",
                "logits_mean",
            ]
        ):
            if config.self_condition_mlp_projection:
                self.project_to_hidden_size = nn.Sequential(
                    nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False),
                    ACT2FN[config.hidden_act],
                    nn.Linear(config.hidden_size, config.hidden_size, bias=False),
                )
            else:
                self.project_to_hidden_size = nn.Linear(
                    config.hidden_size * 2, config.hidden_size, bias=False
                )

        # Initialize weights and apply final processing
        self.post_init()

    # Run the embedding matrix as a linear layer: maps a (batch, seq_len, vocab_size) probability
    # distribution to (batch, seq_len, hidden_size), i.e. a probability-weighted mixture of the
    # word embeddings.
    def vocab_to_hidden_dim_embed(self, input_data):
        return F.linear(input_data, self.roberta.embeddings.word_embeddings.weight.T)

    # def post_init(self):
    #     super().post_init()
    #     self.vocab_to_hidden_dim_embed.weight.data = (
    #         self.get_input_embeddings().weight.data.T
    #     )

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def get_roberta_empty_tokens(self, shape, device):
        if self.config.empty_token_be_mask:
            # 50264 is the `<mask>` token id in the roberta-base/large vocabulary.
            empty_token_ids = (
                torch.ones(shape, dtype=torch.int64, device=device) * 50264
            )
        else:
            # The padding token id in roberta-large is 1.
            empty_token_ids = torch.ones(shape, dtype=torch.int64, device=device)
        # 0 and 2 are the `<s>` and `</s>` token ids, respectively.
        empty_token_ids[:, 0] = 0
        empty_token_ids[:, -1] = 2
        return empty_token_ids
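    # For example (under the roberta-large vocabulary assumed above), calling this with
    # shape=(1, 5) and `empty_token_be_mask=False` yields [[0, 1, 1, 1, 2]],
    # i.e. "<s> <pad> <pad> <pad> </s>".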

    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        classifier_free_guidance: bool = False,
        classifier_free_guidance_in_train: bool = False,
        max_timestep: int = 5000,
        reduce_loss: str = "mean",  # passed to 'reduction' in F.cross_entropy
        # unconditional_simplex: torch.FloatTensor = None,
        return_all_losses: bool = False,  # return per-token loss for all items in batch
        previous_hidden: Optional[torch.FloatTensor] = None,  # for CDCD predictions...
        original_timesteps: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
r""" | |
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): | |
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ..., | |
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the | |
loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]` | |
kwargs (`Dict[str, any]`, optional, defaults to *{}*): | |
Used to hide legacy arguments that have been deprecated. | |
""" | |
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # If we have a mask, we need to mask the simplex values before softmax.
        """
        if span_mask is not None:
            mask_value = torch.finfo(simplex.dtype).min
            mask_value = torch.tensor(mask_value, dtype=simplex.dtype, device=simplex.device)
            simplex = torch.where(span_mask[:, :, None], simplex, mask_value)
        """
        inputs_probs = F.softmax(simplex, dim=-1)
        inputs_embeds = self.vocab_to_hidden_dim_embed(inputs_probs)

        if classifier_free_guidance or classifier_free_guidance_in_train:
            if self.config.classifier_free_simplex_inputs:
                if self.config.classifier_free_uncond_input == "empty_token":
                    empty_token_ids = self.get_roberta_empty_tokens(
                        shape=input_ids.shape, device=input_ids.device
                    )
                    # TODO: fix the simplex_value later.
                    unconditional_simplex = convert_to_simplex(
                        empty_token_ids, 5.0, self.config.vocab_size
                    )
                elif self.config.classifier_free_uncond_input == "noisy_simplex":
                    simplex_shape = (
                        input_ids.shape[0],
                        input_ids.shape[1],
                        self.config.vocab_size,
                    )
                    unconditional_simplex = 5.0 * torch.randn(
                        simplex_shape, device=input_ids.device
                    )
                else:
                    raise NotImplementedError
                unconditional_probs = F.softmax(unconditional_simplex, dim=-1)
                uncond_inputs_embeds = self.vocab_to_hidden_dim_embed(
                    unconditional_probs
                )
            else:
                empty_token_ids = self.get_roberta_empty_tokens(
                    shape=input_ids.shape, device=input_ids.device
                )
                uncond_inputs_embeds = self.get_input_embeddings()(empty_token_ids)

        if self.config.self_condition is not None:
            if self.config.self_condition_zeros_after_softmax and previous_pred is None:
                previous_pred_probs = torch.zeros_like(simplex, device=simplex.device)
            else:
                if previous_pred is None:
                    previous_pred = torch.zeros_like(simplex, device=simplex.device)
                """
                if span_mask is not None:
                    mask_value = torch.finfo(previous_pred.dtype).min
                    mask_value = torch.tensor(mask_value, dtype=previous_pred.dtype, device=previous_pred.device)
                    previous_pred = torch.where(span_mask[:, :, None], previous_pred, mask_value)
                """
                previous_pred_probs = F.softmax(previous_pred, dim=-1)
            if not self.config.self_condition_mix_logits_before_weights:
                previous_pred = self.vocab_to_hidden_dim_embed(previous_pred_probs)
            if not self.config.deepmind_conditional:
                # In this setting, we mix the probabilities then apply the weight.
                if self.config.self_condition_mix_logits_before_weights:
                    mixed_logits = mix_values_based_on_self_condition(
                        self.config.self_condition, simplex, previous_pred
                    )
                    mixed_probs = F.softmax(mixed_logits, dim=-1)
                    inputs_embeds = self.vocab_to_hidden_dim_embed(mixed_probs)
                elif self.config.self_condition_mix_before_weights:
                    mixed_probs = mix_values_based_on_self_condition(
                        self.config.self_condition, inputs_probs, previous_pred_probs
                    )
                    inputs_embeds = self.vocab_to_hidden_dim_embed(mixed_probs)
                else:
                    if self.config.self_condition in [
                        "logits",
                        "logits_with_projection",
                    ]:
                        inputs_embeds = self.project_to_hidden_size(
                            torch.cat([inputs_embeds, previous_pred], dim=-1)
                        )
                    else:
                        inputs_embeds = mix_values_based_on_self_condition(
                            self.config.self_condition, inputs_embeds, previous_pred
                        )

        if span_mask is not None:
            # Original word embeddings without noise.
            if classifier_free_guidance_in_train and random.uniform(0, 1) < 0.1:
                inputs_word_embeds = uncond_inputs_embeds
            else:
                inputs_word_embeds = self.get_input_embeddings()(input_ids)

        if self.config.self_condition is not None and self.config.deepmind_conditional:
            inputs_embeds = torch.where(
                span_mask.unsqueeze(-1), inputs_embeds, torch.zeros_like(previous_pred)
            )
            previous_pred = torch.where(
                span_mask.unsqueeze(-1), previous_pred, torch.zeros_like(previous_pred)
            )
            inputs_word_embeds = torch.where(
                span_mask.unsqueeze(-1),
                torch.zeros_like(inputs_word_embeds),
                inputs_word_embeds,
            )
            tiled_mask = span_mask.unsqueeze(-1).repeat(1, 1, self.config.hidden_size)
            inputs_embeds = self.project_to_hidden_size(
                torch.cat(
                    [inputs_embeds, inputs_word_embeds, previous_pred, tiled_mask],
                    dim=-1,
                )
            )

        bsz = input_ids.shape[0]
        timesteps_embed = self.timestep_embed(timesteps.view(-1, 1).float()).view(
            bsz, -1, self.config.hidden_size
        )
        inputs_embeds = inputs_embeds + timesteps_embed

        if span_mask is not None and not self.config.deepmind_conditional:
            # For the unmasked tokens, we only compute their original word embeddings.
            # Note that this also sets the self-conditioned inputs which we are conditioning on
            # to their original word embedding values.
            inputs_embeds = torch.where(
                span_mask.unsqueeze(-1), inputs_embeds, inputs_word_embeds
            )

        # TODO: we need to fix classifier-free guidance for the case of deepmind_conditional.
        if classifier_free_guidance:
            inputs_embeds = torch.cat([uncond_inputs_embeds, inputs_embeds])

        outputs = self.roberta(
            input_ids=None,  # TODO(rabeeh): we can remove this hack once the loss is moved outside.
            attention_mask=None,  # attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        # With classifier-free guidance, the logits contain both the unconditional and the
        # conditional halves, so only the conditional half is used when computing the loss.
        if input_ids is not None:
            # In case of classifier-free guidance we need to get rid of the unconditional part.
            prediction_scores_for_loss = (
                prediction_scores.chunk(2)[1]
                if classifier_free_guidance
                else prediction_scores
            )
            loss_fct = CrossEntropyLoss(reduction=reduce_loss)
            labels = (
                torch.where(span_mask, input_ids, -100)
                if span_mask is not None
                else input_ids
            )
            if self.config.mask_padding_in_loss:
                # Also exclude padding tokens from the loss.
                labels = torch.where(labels == self.config.pad_token_id, -100, labels)
            masked_lm_loss = loss_fct(
                prediction_scores_for_loss.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if return_all_losses:
                # Per-token losses for every item in the batch (assumes reduce_loss == "none").
                all_lm_losses = masked_lm_loss.view(input_ids.shape[0], -1)
            if reduce_loss == "none":
                # Take the average loss over tokens, not counting the masked tokens.
                masked_lm_loss = masked_lm_loss.view(input_ids.shape[0], -1)
                masked_lm_loss = masked_lm_loss.sum(dim=-1) / span_mask.sum(dim=-1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=all_lm_losses if return_all_losses else masked_lm_loss,
            logits=prediction_scores,
            # Note: this is the final-layer hidden state, not the usual tuple of per-layer hidden states.
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )
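
    # Typical usage (a sketch inferred from the arguments above; the actual sampling loop lives
    # elsewhere in sdlm): an external diffusion loop repeatedly calls `forward` with the current
    # noisy `simplex`, the step's `timesteps`, and the `span_mask` marking the positions to
    # generate. When self-conditioning is enabled, the `logits` returned at one step are fed back
    # as `previous_pred` at the next step, and `hidden_states` can likewise be passed back as
    # `previous_hidden` for CDCD-style variants.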

    def resize_position_embeddings(
        self, new_num_position_embeddings: int, with_alternatation=False
    ):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The new number of position embeddings. If position embeddings are learned, increasing the size
                will add newly initialized vectors at the end, whereas reducing the size will remove vectors from the
                end. If position embeddings are not learned (*e.g.* sinusoidal position embeddings), increasing the
                size will add correct vectors at the end following the position encoding algorithm, whereas reducing
                the size will remove vectors from the end.
        """
        num_position_embeds_diff = (
            new_num_position_embeddings - self.config.max_position_embeddings
        )

        # No resizing needs to be done if the length stays the same.
        if num_position_embeds_diff == 0:
            return

        logger.info(
            f"Setting `config.max_position_embeddings={new_num_position_embeddings}`..."
        )
        self.config.max_position_embeddings = new_num_position_embeddings

        old_position_embeddings_weight = (
            self.roberta.embeddings.position_embeddings.weight.clone()
        )
        padding_idx = self.config.pad_token_id
        self.roberta.embeddings.position_embeddings = nn.Embedding(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            padding_idx=padding_idx,
        )

        with torch.no_grad():
            if num_position_embeds_diff > 0:
                self.roberta.embeddings.position_embeddings.weight[
                    :-num_position_embeds_diff
                ] = nn.Parameter(old_position_embeddings_weight)
                if with_alternatation:
                    self.roberta.embeddings.position_embeddings.weight[
                        -num_position_embeds_diff:
                    ] = nn.Parameter(
                        old_position_embeddings_weight[:num_position_embeds_diff]
                    )
            else:
                self.roberta.embeddings.position_embeddings.weight = nn.Parameter(
                    old_position_embeddings_weight[:num_position_embeds_diff]
                )
        # Move position_embeddings to the correct device.
        self.roberta.embeddings.position_embeddings.to(self.device)

        # Update other needed parameters.
        self.roberta.embeddings.position_ids = (
            torch.arange(self.config.max_position_embeddings)
            .expand((1, -1))
            .type_as(self.roberta.embeddings.position_ids)
        )
        self.roberta.embeddings.token_type_ids = torch.zeros(
            self.roberta.embeddings.position_ids.size(), dtype=torch.long
        ).type_as(self.roberta.embeddings.token_type_ids)

        # Resize the distance embeddings.
        for i in range(self.config.num_hidden_layers):
            if (
                self.config.position_embedding_type == "relative_key"
                or self.config.position_embedding_type == "relative_key_query"
            ):
                # Clone the old distance embedding weights *before* replacing the module,
                # reading them from the layer itself.
                old_distance_embedding_weight = self.roberta.encoder.layer[
                    i
                ].attention.self.distance_embedding.weight.clone()
                self.roberta.encoder.layer[
                    i
                ].attention.self.distance_embedding = nn.Embedding(
                    2 * self.config.max_position_embeddings - 1,
                    self.roberta.encoder.layer[i].attention.self.attention_head_size,
                )
                with torch.no_grad():
                    if num_position_embeds_diff > 0:
                        self.roberta.encoder.layer[
                            i
                        ].attention.self.distance_embedding.weight[
                            : -2 * num_position_embeds_diff
                        ] = nn.Parameter(old_distance_embedding_weight)
                    else:
                        self.roberta.encoder.layer[
                            i
                        ].attention.self.distance_embedding.weight = nn.Parameter(
                            old_distance_embedding_weight[
                                : 2 * num_position_embeds_diff
                            ]
                        )
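

# A minimal smoke-test sketch: it assumes `sdlm` is importable (for the imports at the top of this
# file), builds a tiny RobertaConfig carrying the extra flags this class reads (`self_condition`,
# `deepmind_conditional`, `mask_padding_in_loss`), and runs a single forward pass with a
# near-one-hot simplex built directly via F.one_hot rather than `convert_to_simplex`.
if __name__ == "__main__":
    from transformers import RobertaConfig

    config = RobertaConfig(
        vocab_size=100,
        hidden_size=32,
        num_hidden_layers=2,
        num_attention_heads=2,
        intermediate_size=64,
        max_position_embeddings=40,
        # Extra, project-specific flags read by RobertaForDiffusionLM.
        self_condition=None,
        deepmind_conditional=False,
        mask_padding_in_loss=False,
    )
    model = RobertaForDiffusionLM(config)

    batch_size, seq_len, simplex_value = 2, 8, 5.0
    input_ids = torch.randint(3, config.vocab_size, (batch_size, seq_len))
    # Noisy simplex over the vocabulary: roughly +simplex_value at each token's index and
    # -simplex_value elsewhere, plus Gaussian noise.
    simplex = simplex_value * (
        2 * F.one_hot(input_ids, config.vocab_size).float() - 1
    ) + torch.randn(batch_size, seq_len, config.vocab_size)
    timesteps = torch.rand(batch_size)  # one (normalized) timestep per sequence
    span_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
    span_mask[:, 2:6] = True  # positions to denoise; the rest serve as context

    outputs = model(
        timesteps=timesteps,
        input_ids=input_ids,
        simplex=simplex,
        span_mask=span_mask,
    )
    print(outputs.loss, outputs.logits.shape)  # scalar loss, (2, 8, 100) logits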