File size: 19,862 Bytes
17ff0d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
import random
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
from transformers.activations import ACT2FN
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.roberta.modeling_roberta import (
    RobertaLMHead,
    RobertaModel,
    RobertaPreTrainedModel,
)
from transformers.utils import logging

from sdlm.utils import convert_to_simplex, mix_values_based_on_self_condition

# Module-level logger from `transformers.utils.logging`, named after this module.
logger = logging.get_logger(__name__)


class RobertaForDiffusionLM(RobertaPreTrainedModel):
    """RoBERTa encoder that maps a noisy vocabulary simplex to vocabulary logits.

    The input simplex is softmax-normalised and projected through the shared
    word-embedding matrix. A learned timestep embedding is added, and the
    result is optionally mixed with self-conditioning inputs (the previous
    prediction) before it is fed to the RoBERTa encoder and the LM head.
    """

    _keys_to_ignore_on_save = [r"lm_head.decoder.weight", r"lm_head.decoder.bias"]
    _keys_to_ignore_on_load_missing = [
        r"position_ids",
        r"lm_head.decoder.weight",
        r"lm_head.decoder.bias",
    ]
    _keys_to_ignore_on_load_unexpected = [r"pooler"]

    def __init__(self, config):
        super().__init__(config)

        if config.is_decoder:
            logger.warning(
                "If you want to use `RobertaForDiffusionLM` make sure `config.is_decoder=False` for "
                "bi-directional self-attention."
            )

        self.roberta = RobertaModel(config, add_pooling_layer=False)
        self.lm_head = RobertaLMHead(config)

        # Projects the scalar timestep to hidden_size; added to the input embeddings in `forward`.
        # (The vocab -> hidden projection reuses the word-embedding matrix, see `vocab_to_hidden_dim_embed`.)
        self.timestep_embed = nn.Linear(1, config.hidden_size, bias=True)

        if self.config.self_condition is not None and self.config.deepmind_conditional:
            # Self-conditioning with conditional generation as in the DeepMind paper
            # (Figure 3 of https://arxiv.org/pdf/2211.15089.pdf): masked word embeddings,
            # noisy embeddings, the mask and the self-conditioning inputs are concatenated
            # and projected back to hidden_size.
            self.project_to_hidden_size = nn.Linear(
                config.hidden_size * 4, config.hidden_size, bias=False
            )
        elif self.config.self_condition is not None and self.config.self_condition not in (
            "logits_addition",
            "logits_with_projection_addition",
            "logits_max",
            "logits_mean",
        ):
            # Concatenation-style self-conditioning: [inputs_embeds, previous_pred] -> hidden_size.
            if config.self_condition_mlp_projection:
                self.project_to_hidden_size = nn.Sequential(
                    nn.Linear(config.hidden_size * 2, config.hidden_size, bias=False),
                    ACT2FN[config.hidden_act],
                    nn.Linear(config.hidden_size, config.hidden_size, bias=False),
                )
            else:
                self.project_to_hidden_size = nn.Linear(
                    config.hidden_size * 2, config.hidden_size, bias=False
                )

        # Initialize weights and apply final processing
        self.post_init()

    def vocab_to_hidden_dim_embed(self, input_data):
        """Project vocabulary-sized vectors to hidden_size with the word-embedding matrix.

        `F.linear(x, W.T)` computes `x @ W`, where `W` is the (vocab_size, hidden_size)
        word-embedding weight. For a probability simplex this is the expected embedding.
        """
        return F.linear(input_data, self.roberta.embeddings.word_embeddings.weight.T)

    def get_output_embeddings(self):
        return self.lm_head.decoder

    def set_output_embeddings(self, new_embeddings):
        self.lm_head.decoder = new_embeddings

    def get_roberta_empty_tokens(self, shape, device):
        """Build placeholder token ids for the unconditional classifier-free-guidance branch.

        Every position is filled with 50264 when `config.empty_token_be_mask` is set
        (presumably RoBERTa's `<mask>` id), otherwise with 1 (RoBERTa's padding id).
        The first and last columns are then set to 0 and 2 (presumably `<s>` / `</s>`).
        Assumes `shape` is 2-D (batch_size, sequence_length).
        """
        if self.config.empty_token_be_mask:
            empty_token_ids = (
                torch.ones(shape, dtype=torch.int64, device=device) * 50264
            )
        else:
            # Padding token in roberta-large is 1.
            empty_token_ids = torch.ones(shape, dtype=torch.int64, device=device)
        empty_token_ids[:, 0] = 0
        empty_token_ids[:, -1] = 2
        return empty_token_ids

    def forward(
        self,
        timesteps: torch.FloatTensor,
        input_ids: torch.LongTensor,
        simplex: torch.FloatTensor,
        span_mask: Optional[torch.FloatTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        previous_pred: Optional[torch.FloatTensor] = None,
        classifier_free_guidance: bool = False,
        classifier_free_guidance_in_train: bool = False,
        max_timestep: int = 5000,
        reduce_loss: str = "mean",  # passed to 'reduction' in F.cross_entropy
        return_all_losses: bool = False,  # return per-token loss for all items in batch
        previous_hidden: Optional[torch.FloatTensor] = None,  # for CDCD predictions...
        original_timesteps: Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple[torch.Tensor], MaskedLMOutput]:
        r"""
        Denoise a noisy simplex and return vocabulary logits plus a masked-LM loss.

        Args:
            timesteps: Diffusion timesteps. They are embedded by `timestep_embed`, reshaped to
                `(batch_size, -1, hidden_size)` and added to the input embeddings.
            input_ids: `(batch_size, sequence_length)` clean token ids. They supply the embeddings of
                positions outside `span_mask` and the loss targets.
            simplex: Noisy logits over the vocabulary, softmax-normalised along the last dimension.
            span_mask: Boolean `(batch_size, sequence_length)` mask; `True` marks positions to denoise.
                Other positions are fed their clean word embeddings and ignored by the loss.
            labels: Ignored. Targets are derived from `input_ids`: positions outside `span_mask` (and padding,
                if `config.mask_padding_in_loss`) are set to -100.
            previous_pred: Previous-step logits over the vocabulary used for self-conditioning (zeros if None).
            classifier_free_guidance: Prepend an unconditional copy of the batch (only when `span_mask` is given
                and `config.deepmind_conditional` is off). The returned logits then hold 2 * batch_size rows, and the
                loss uses the conditional half.
            classifier_free_guidance_in_train: With probability 0.1, replace the clean embeddings of unmasked
                positions with unconditional embeddings.
            reduce_loss: `reduction` passed to `CrossEntropyLoss`. With "none", the loss is averaged per sequence
                over the tokens that actually contribute to it.
            return_all_losses: Return the per-token loss reshaped to `(batch_size, -1)` as `loss`. This is only
                meaningful with `reduce_loss="none"`.
            max_timestep, previous_hidden, original_timesteps: Not used by this method; accepted for interface
                compatibility.

        Returns:
            `MaskedLMOutput`, whose `hidden_states` is the encoder's last hidden state (not the per-layer tuple),
            or a tuple when `return_dict` is False.
        """
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Expected embedding of the noisy simplex under the word-embedding matrix.
        inputs_probs = F.softmax(simplex, dim=-1)
        inputs_embeds = self.vocab_to_hidden_dim_embed(inputs_probs)

        # Unconditional inputs for classifier-free guidance (at inference or during training).
        if classifier_free_guidance or classifier_free_guidance_in_train:
            if self.config.classifier_free_simplex_inputs:
                if self.config.classifier_free_uncond_input == "empty_token":
                    empty_token_ids = self.get_roberta_empty_tokens(
                        shape=input_ids.shape, device=input_ids.device
                    )
                    # TODO: fix the simplex_value later.
                    unconditional_simplex = convert_to_simplex(
                        empty_token_ids, 5.0, self.config.vocab_size
                    )
                elif self.config.classifier_free_uncond_input == "noisy_simplex":
                    simplex_shape = (
                        input_ids.shape[0],
                        input_ids.shape[1],
                        self.config.vocab_size,
                    )
                    unconditional_simplex = 5.0 * torch.randn(
                        simplex_shape, device=input_ids.device
                    )
                else:
                    raise NotImplementedError
                unconditional_probs = F.softmax(unconditional_simplex, dim=-1)
                uncond_inputs_embeds = self.vocab_to_hidden_dim_embed(
                    unconditional_probs
                )
            else:
                empty_token_ids = self.get_roberta_empty_tokens(
                    shape=input_ids.shape, device=input_ids.device
                )
                uncond_inputs_embeds = self.get_input_embeddings()(empty_token_ids)

        if self.config.self_condition is not None:
            if self.config.self_condition_zeros_after_softmax and previous_pred is None:
                previous_pred_probs = torch.zeros_like(simplex, device=simplex.device)
            else:
                if previous_pred is None:
                    previous_pred = torch.zeros_like(simplex, device=simplex.device)
                previous_pred_probs = F.softmax(previous_pred, dim=-1)
            if not self.config.self_condition_mix_logits_before_weights:
                # From here on `previous_pred` holds hidden-size embeddings, not logits.
                previous_pred = self.vocab_to_hidden_dim_embed(previous_pred_probs)
            if not self.config.deepmind_conditional:
                if self.config.self_condition_mix_logits_before_weights:
                    # Mix in logit space, then softmax and embed.
                    mixed_logits = mix_values_based_on_self_condition(
                        self.config.self_condition, simplex, previous_pred
                    )
                    mixed_probs = F.softmax(mixed_logits, dim=-1)
                    inputs_embeds = self.vocab_to_hidden_dim_embed(mixed_probs)
                elif self.config.self_condition_mix_before_weights:
                    # Mix in probability space, then embed.
                    mixed_probs = mix_values_based_on_self_condition(
                        self.config.self_condition, inputs_probs, previous_pred_probs
                    )
                    inputs_embeds = self.vocab_to_hidden_dim_embed(mixed_probs)
                elif self.config.self_condition in ["logits", "logits_with_projection"]:
                    # Concatenate embeddings and project back to hidden_size.
                    inputs_embeds = self.project_to_hidden_size(
                        torch.cat([inputs_embeds, previous_pred], axis=-1)
                    )
                else:
                    # Mix in embedding space.
                    inputs_embeds = mix_values_based_on_self_condition(
                        self.config.self_condition, inputs_embeds, previous_pred
                    )

        if span_mask is not None:
            # Original word embeddings without noise (occasionally dropped for CFG training).
            if classifier_free_guidance_in_train and random.uniform(0, 1) < 0.1:
                inputs_word_embeds = uncond_inputs_embeds
            else:
                inputs_word_embeds = self.get_input_embeddings()(input_ids)

        if self.config.self_condition is not None and self.config.deepmind_conditional:
            # Zero noisy/self-conditioning inputs outside the span and clean embeddings inside it,
            # then project the concatenation (plus the mask itself) to hidden_size.
            inputs_embeds = torch.where(
                span_mask.unsqueeze(-1), inputs_embeds, torch.zeros_like(previous_pred)
            )
            previous_pred = torch.where(
                span_mask.unsqueeze(-1), previous_pred, torch.zeros_like(previous_pred)
            )
            inputs_word_embeds = torch.where(
                span_mask.unsqueeze(-1),
                torch.zeros_like(inputs_word_embeds),
                inputs_word_embeds,
            )
            tiled_mask = span_mask.unsqueeze(-1).repeat(1, 1, self.config.hidden_size)
            inputs_embeds = self.project_to_hidden_size(
                torch.cat(
                    [inputs_embeds, inputs_word_embeds, previous_pred, tiled_mask],
                    axis=-1,
                )
            )

        bsz = input_ids.shape[0]
        timesteps_embed = self.timestep_embed(timesteps.view(-1, 1).float()).view(
            bsz, -1, self.config.hidden_size
        )
        inputs_embeds = inputs_embeds + timesteps_embed

        if span_mask is not None and not self.config.deepmind_conditional:
            # Unmasked tokens (including self-conditioned inputs we condition on) get their
            # original word embeddings.
            inputs_embeds = torch.where(
                span_mask.unsqueeze(-1), inputs_embeds, inputs_word_embeds
            )
            # TODO: we need to fix classifier-free guidance for the case of deepmind_conditional.
            # NOTE(review): the unconditional batch is only prepended here (span_mask given); with
            # span_mask=None and classifier_free_guidance=True the chunk(2) below splits the
            # conditional batch instead — confirm callers never combine those.
            if classifier_free_guidance:
                inputs_embeds = torch.cat([uncond_inputs_embeds, inputs_embeds])
        outputs = self.roberta(
            input_ids=None,  # TODO(rabeeh): we can remove this hack when we moved loss to outside.
            attention_mask=None,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        sequence_output = outputs[0]
        prediction_scores = self.lm_head(sequence_output)

        masked_lm_loss = None
        all_lm_losses = None
        if input_ids is not None:
            # With classifier-free guidance the first half of the logits is the unconditional
            # copy; only the conditional half lines up with `input_ids`.
            prediction_scores_for_loss = (
                prediction_scores.chunk(2)[1]
                if classifier_free_guidance
                else prediction_scores
            )
            loss_fct = CrossEntropyLoss(reduction=reduce_loss)
            labels = (
                torch.where(span_mask, input_ids, -100)
                if span_mask is not None
                else input_ids
            )
            if self.config.mask_padding_in_loss:
                # Also exclude padding tokens from the loss.
                labels = torch.where(labels == self.config.pad_token_id, -100, labels)
            masked_lm_loss = loss_fct(
                prediction_scores_for_loss.view(-1, self.config.vocab_size),
                labels.view(-1),
            )
            if return_all_losses:
                all_lm_losses = masked_lm_loss.view(input_ids.shape[0], -1)
            if reduce_loss == "none":
                # Per-sequence mean over the tokens that actually contribute to the loss.
                # Counting non-ignored labels (rather than span_mask) also works without a
                # span_mask and excludes padding removed above; clamp avoids 0/0 -> NaN.
                per_token_loss = masked_lm_loss.view(input_ids.shape[0], -1)
                num_loss_tokens = (labels != -100).view(input_ids.shape[0], -1).sum(dim=-1)
                masked_lm_loss = per_token_loss.sum(dim=-1) / num_loss_tokens.clamp(min=1)

        if not return_dict:
            output = (prediction_scores,) + outputs[2:]
            return (
                ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
            )

        return MaskedLMOutput(
            loss=all_lm_losses if return_all_losses else masked_lm_loss,
            logits=prediction_scores,
            hidden_states=outputs.last_hidden_state,
            attentions=outputs.attentions,
        )

    def resize_position_embeddings(
        self, new_num_position_embeddings: int, with_alternatation=False
    ):
        """
        Resizes position embeddings of the model if `new_num_position_embeddings != config.max_position_embeddings`.

        Arguments:
            new_num_position_embeddings (`int`):
                The new number of absolute position embeddings. When growing, the existing vectors are kept at the
                front and the added rows at the end are freshly initialized. When shrinking, vectors are removed
                from the end.
            with_alternatation (`bool`, *optional*, defaults to `False`):
                When growing, fill the added rows with copies of the first `new - old` existing position vectors
                instead of leaving them freshly initialized.

        For `relative_key` / `relative_key_query` position embeddings, each self-attention layer's
        `distance_embedding` (`2 * max_position_embeddings - 1` rows) is resized too. The existing vectors are kept
        aligned with the same relative distances.
        """
        num_position_embeds_diff = (
            new_num_position_embeddings - self.config.max_position_embeddings
        )

        # No resizing needs to be done if the length stays the same.
        if num_position_embeds_diff == 0:
            return

        logger.info(
            f"Setting `config.max_position_embeddings={new_num_position_embeddings}`..."
        )
        self.config.max_position_embeddings = new_num_position_embeddings
        embeddings = self.roberta.embeddings

        # Absolute position embeddings.
        old_position_embeddings_weight = (
            embeddings.position_embeddings.weight.detach().clone()
        )
        embeddings.position_embeddings = nn.Embedding(
            self.config.max_position_embeddings,
            self.config.hidden_size,
            padding_idx=self.config.pad_token_id,
        )
        with torch.no_grad():
            new_weight = embeddings.position_embeddings.weight
            if num_position_embeds_diff > 0:
                new_weight[:-num_position_embeds_diff] = old_position_embeddings_weight
                if with_alternatation:
                    new_weight[-num_position_embeds_diff:] = old_position_embeddings_weight[
                        :num_position_embeds_diff
                    ]
            else:
                embeddings.position_embeddings.weight = nn.Parameter(
                    old_position_embeddings_weight[:num_position_embeds_diff]
                )
        # Move position_embeddings to the model's device.
        embeddings.position_embeddings.to(self.device)

        # Update other needed buffers.
        embeddings.position_ids = (
            torch.arange(self.config.max_position_embeddings)
            .expand((1, -1))
            .type_as(embeddings.position_ids)
        )
        embeddings.token_type_ids = torch.zeros(
            embeddings.position_ids.size(), dtype=torch.long
        ).type_as(embeddings.token_type_ids)

        # Relative-position distance embeddings.
        if self.config.position_embedding_type not in (
            "relative_key",
            "relative_key_query",
        ):
            return
        for layer in self.roberta.encoder.layer:
            self_attention = layer.attention.self
            # Clone BEFORE replacing the module, otherwise we would copy fresh random weights.
            old_distance_embedding_weight = (
                self_attention.distance_embedding.weight.detach().clone()
            )
            self_attention.distance_embedding = nn.Embedding(
                2 * self.config.max_position_embeddings - 1,
                self_attention.attention_head_size,
            )
            # transformers' self-attention looks up distance d at row
            # d + max_position_embeddings - 1, so the offset must follow the new size and the
            # old vectors must stay centred (shifted by the size difference on each side).
            self_attention.max_position_embeddings = self.config.max_position_embeddings
            with torch.no_grad():
                if num_position_embeds_diff > 0:
                    self_attention.distance_embedding.weight[
                        num_position_embeds_diff:-num_position_embeds_diff
                    ] = old_distance_embedding_weight
                else:
                    self_attention.distance_embedding.weight = nn.Parameter(
                        old_distance_embedding_weight[
                            -num_position_embeds_diff:num_position_embeds_diff
                        ]
                    )
            self_attention.distance_embedding.to(self.device)