""" Copyright (c) Meta Platforms, Inc. and affiliates. All rights reserved. This source code is licensed under the license found in the LICENSE file in the root directory of this source tree. """ import math from typing import Any, Callable, List, Optional, Union import torch import torch.nn as nn from einops import rearrange from torch import Tensor from torch.nn import functional as F def generate_causal_mask(source_length, target_length, device="cpu"): if source_length == target_length: mask = ( torch.triu(torch.ones(target_length, source_length, device=device)) == 1 ).transpose(0, 1) else: mask = torch.zeros(target_length, source_length, device=device) idx = torch.linspace(0, source_length, target_length + 1)[1:].round().long() for i in range(target_length): mask[i, 0 : idx[i]] = 1 return ( mask.float() .masked_fill(mask == 0, float("-inf")) .masked_fill(mask == 1, float(0.0)) ) class TransformerEncoderLayerRotary(nn.Module): def __init__( self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation: Union[str, Callable[[Tensor], Tensor]] = F.relu, layer_norm_eps: float = 1e-5, batch_first: bool = False, norm_first: bool = True, rotary=None, ) -> None: super().__init__() self.self_attn = nn.MultiheadAttention( d_model, nhead, dropout=dropout, batch_first=batch_first ) # Implementation of Feedforward model self.linear1 = nn.Linear(d_model, dim_feedforward) self.dropout = nn.Dropout(dropout) self.linear2 = nn.Linear(dim_feedforward, d_model) self.norm_first = norm_first self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps) self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.activation = activation self.rotary = rotary self.use_rotary = rotary is not None def forward( self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None, ) -> Tensor: x = src if self.norm_first: x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask) x = x + self._ff_block(self.norm2(x)) else: x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask)) x = self.norm2(x + self._ff_block(x)) return x # self-attention block def _sa_block( self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor] ) -> Tensor: qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x x = self.self_attn( qk, qk, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, )[0] return self.dropout1(x) # feed forward block def _ff_block(self, x: Tensor) -> Tensor: x = self.linear2(self.dropout(self.activation(self.linear1(x)))) return self.dropout2(x) class DenseFiLM(nn.Module): """Feature-wise linear modulation (FiLM) generator.""" def __init__(self, embed_channels): super().__init__() self.embed_channels = embed_channels self.block = nn.Sequential( nn.Mish(), nn.Linear(embed_channels, embed_channels * 2) ) def forward(self, position): pos_encoding = self.block(position) pos_encoding = rearrange(pos_encoding, "b c -> b 1 c") scale_shift = pos_encoding.chunk(2, dim=-1) return scale_shift def featurewise_affine(x, scale_shift): scale, shift = scale_shift return (scale + 1) * x + shift class FiLMTransformerDecoderLayer(nn.Module): def __init__( self, d_model: int, nhead: int, dim_feedforward=2048, dropout=0.1, activation=F.relu, layer_norm_eps=1e-5, batch_first=False, norm_first=True, rotary=None, use_cm=False, ): super().__init__() self.self_attn = nn.MultiheadAttention( d_model, nhead, dropout=dropout, 
class FiLMTransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        layer_norm_eps=1e-5,
        batch_first=False,
        norm_first=True,
        rotary=None,
        use_cm=False,
    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        self.multihead_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first
        )
        # Feedforward
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm_first = norm_first
        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.activation = activation

        self.film1 = DenseFiLM(d_model)
        self.film2 = DenseFiLM(d_model)
        self.film3 = DenseFiLM(d_model)

        if use_cm:
            self.multihead_attn2 = nn.MultiheadAttention(  # 2
                d_model, nhead, dropout=dropout, batch_first=batch_first
            )
            self.norm2a = nn.LayerNorm(d_model, eps=layer_norm_eps)  # 2
            self.dropout2a = nn.Dropout(dropout)  # 2
            self.film2a = DenseFiLM(d_model)  # 2

        self.rotary = rotary
        self.use_rotary = rotary is not None

    # x, cond, t
    def forward(
        self,
        tgt,
        memory,
        t,
        tgt_mask=None,
        memory_mask=None,
        tgt_key_padding_mask=None,
        memory_key_padding_mask=None,
        memory2=None,
    ):
        x = tgt
        if self.norm_first:
            # self-attention -> film -> residual
            x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
            x = x + featurewise_affine(x_1, self.film1(t))
            # cross-attention -> film -> residual
            x_2 = self._mha_block(
                self.norm2(x),
                memory,
                memory_mask,
                memory_key_padding_mask,
                self.multihead_attn,
                self.dropout2,
            )
            x = x + featurewise_affine(x_2, self.film2(t))
            if memory2 is not None:
                # cross-attention x2 -> film -> residual
                x_2a = self._mha_block(
                    self.norm2a(x),
                    memory2,
                    memory_mask,
                    memory_key_padding_mask,
                    self.multihead_attn2,
                    self.dropout2a,
                )
                x = x + featurewise_affine(x_2a, self.film2a(t))
            # feedforward -> film -> residual
            x_3 = self._ff_block(self.norm3(x))
            x = x + featurewise_affine(x_3, self.film3(t))
        else:
            x = self.norm1(
                x
                + featurewise_affine(
                    self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
                )
            )
            x = self.norm2(
                x
                + featurewise_affine(
                    self._mha_block(
                        x,
                        memory,
                        memory_mask,
                        memory_key_padding_mask,
                        self.multihead_attn,
                        self.dropout2,
                    ),
                    self.film2(t),
                )
            )
            x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
        return x

    # self-attention block
    # qkv
    def _sa_block(self, x, attn_mask, key_padding_mask):
        qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        x = self.self_attn(
            qk,
            qk,
            x,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return self.dropout1(x)

    # multihead attention block
    # qkv
    def _mha_block(self, x, mem, attn_mask, key_padding_mask, mha, dropout):
        q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
        k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
        x = mha(
            q,
            k,
            mem,
            attn_mask=attn_mask,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )[0]
        return dropout(x)

    # feed forward block
    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout3(x)


class DecoderLayerStack(nn.Module):
    def __init__(self, stack):
        super().__init__()
        self.stack = stack

    def forward(self, x, cond, t, tgt_mask=None, memory2=None):
        for layer in self.stack:
            x = layer(x, cond, t, tgt_mask=tgt_mask, memory2=memory2)
        return x
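
# Illustrative sketch (not part of the original API): wiring a stack of
# FiLM-conditioned decoder layers and running one forward pass. The sizes and
# the `_demo_decoder_stack` name are assumptions for illustration only;
# batch_first=True keeps the B x T x d_model layout expected by DenseFiLM.
def _demo_decoder_stack():
    d_model, nhead, batch, seq_len = 64, 4, 2, 16
    stack = DecoderLayerStack(
        nn.ModuleList(
            [
                FiLMTransformerDecoderLayer(d_model, nhead, batch_first=True)
                for _ in range(2)
            ]
        )
    )
    x = torch.randn(batch, seq_len, d_model)     # noised target sequence
    cond = torch.randn(batch, seq_len, d_model)  # cross-attention memory
    t = torch.randn(batch, d_model)              # per-sample timestep embedding
    tgt_mask = generate_causal_mask(seq_len, seq_len)
    return stack(x, cond, t, tgt_mask=tgt_mask)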
class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor):
        """
        :param x: B x T x d_model tensor
        :return: B x T x d_model tensor
        """
        x = x + self.pe[None, : x.shape[1], :]
        x = self.dropout(x)
        return x


class TimestepEncoding(nn.Module):
    def __init__(self, embedding_dim: int):
        super().__init__()
        # Fourier embedding
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        self.register_buffer("emb", emb)
        # encoding
        self.encoding = nn.Sequential(
            nn.Linear(embedding_dim, 4 * embedding_dim),
            nn.Mish(),
            nn.Linear(4 * embedding_dim, embedding_dim),
        )

    def forward(self, t: torch.Tensor):
        """
        :param t: B-dimensional tensor containing timesteps in range [0, 1]
        :return: B x embedding_dim tensor containing timestep encodings
        """
        x = t[:, None] * self.emb[None, :]
        x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
        x = self.encoding(x)
        return x


class FiLM(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.dim = dim
        self.film = nn.Sequential(nn.Mish(), nn.Linear(dim, dim * 2))

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        :param x: ... x dim tensor
        :param cond: ... x dim tensor
        :return: ... x dim tensor as scale(cond) * x + bias(cond)
        """
        cond = self.film(cond)
        scale, bias = torch.chunk(cond, chunks=2, dim=-1)
        x = (scale + 1) * x + bias
        return x
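
# Illustrative sketch (not part of the original API): encoding diffusion
# timesteps and using them to modulate a feature sequence via FiLM.
# `_demo_film` is a hypothetical helper for illustration only.
def _demo_film():
    batch, seq_len, dim = 2, 16, 64
    timestep_encoder = TimestepEncoding(dim)
    film = FiLM(dim)
    t = torch.rand(batch)                     # timesteps in [0, 1]
    t_emb = timestep_encoder(t).unsqueeze(1)  # B x 1 x dim
    x = torch.randn(batch, seq_len, dim)
    return film(x, t_emb)                     # scale/bias broadcast over time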
x d_model tensor """ return self.ff(x) class SelfAttention(nn.Module): def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1): super().__init__() self.self_attn = nn.MultiheadAttention( d_model, num_heads, dropout=dropout, batch_first=True ) self.dropout = nn.Dropout(p=dropout) def forward( self, x: torch.Tensor, attn_mask: torch.Tensor = None, key_padding_mask: torch.Tensor = None, ): """ :param x: B x T x d_model input tensor :param attn_mask: B * num_heads x L x S mask with L=target sequence length, S=source sequence length for a float mask: values will be added to attention weight for a binary mask: True indicates that the element is not allowed to attend :param key_padding_mask: B x S mask for a float mask: values will be added directly to the corresponding key values for a binary mask: True indicates that the corresponding key value will be ignored :return: B x T x d_model output tensor """ x = self.self_attn( x, x, x, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, )[0] x = self.dropout(x) return x class CrossAttention(nn.Module): def __init__(self, d_model: int, d_cond: int, num_heads: int, dropout: float = 0.1): super().__init__() self.cross_attn = nn.MultiheadAttention( d_model, num_heads, dropout=dropout, batch_first=True, kdim=d_cond, vdim=d_cond, ) self.dropout = nn.Dropout(p=dropout) def forward( self, x: torch.Tensor, cond: torch.Tensor, attn_mask: torch.Tensor = None, key_padding_mask: torch.Tensor = None, ): """ :param x: B x T_target x d_model input tensor :param cond: B x T_cond x d_cond condition tensor :param attn_mask: B * num_heads x L x S mask with L=target sequence length, S=source sequence length for a float mask: values will be added to attention weight for a binary mask: True indicates that the element is not allowed to attend :param key_padding_mask: B x S mask for a float mask: values will be added directly to the corresponding key values for a binary mask: True indicates that the corresponding key value will be ignored :return: B x T x d_model output tensor """ x = self.cross_attn( x, cond, cond, attn_mask=attn_mask, key_padding_mask=key_padding_mask, need_weights=False, )[0] x = self.dropout(x) return x class TransformerEncoderLayer(nn.Module): def __init__( self, d_model: int, num_heads: int, d_feedforward: int = 1024, dropout: float = 0.1, ): super().__init__() self.norm1 = nn.LayerNorm(d_model) self.self_attn = SelfAttention(d_model, num_heads, dropout) self.norm2 = nn.LayerNorm(d_model) self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout) def forward( self, x: torch.Tensor, mask: torch.Tensor = None, key_padding_mask: torch.Tensor = None, ): x = x + self.self_attn(self.norm1(x), mask, key_padding_mask) x = x + self.feedforward(self.norm2(x)) return x class TransformerDecoderLayer(nn.Module): def __init__( self, d_model: int, d_cond: int, num_heads: int, d_feedforward: int = 1024, dropout: float = 0.1, ): super().__init__() self.norm1 = nn.LayerNorm(d_model) self.self_attn = SelfAttention(d_model, num_heads, dropout) self.norm2 = nn.LayerNorm(d_model) self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout) self.norm3 = nn.LayerNorm(d_model) self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout) def forward( self, x: torch.Tensor, cross_cond: torch.Tensor, target_mask: torch.Tensor = None, target_key_padding_mask: torch.Tensor = None, cross_cond_mask: torch.Tensor = None, cross_cond_key_padding_mask: torch.Tensor = None, ): """ :param x: B x T x d_model tensor :param 
class TransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_cond: int,
        num_heads: int,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = SelfAttention(d_model, num_heads, dropout)
        self.norm2 = nn.LayerNorm(d_model)
        self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout)
        self.norm3 = nn.LayerNorm(d_model)
        self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)

    def forward(
        self,
        x: torch.Tensor,
        cross_cond: torch.Tensor,
        target_mask: torch.Tensor = None,
        target_key_padding_mask: torch.Tensor = None,
        cross_cond_mask: torch.Tensor = None,
        cross_cond_key_padding_mask: torch.Tensor = None,
    ):
        """
        :param x: B x T x d_model tensor
        :param cross_cond: B x T x d_cond tensor containing the conditioning input
            to cross attention layers
        :return: B x T x d_model tensor
        """
        x = x + self.self_attn(self.norm1(x), target_mask, target_key_padding_mask)
        x = x + self.cross_attn(
            self.norm2(x), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
        )
        x = x + self.feedforward(self.norm3(x))
        return x


class FilmTransformerDecoderLayer(nn.Module):
    def __init__(
        self,
        d_model: int,
        d_cond: int,
        num_heads: int,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.self_attn = SelfAttention(d_model, num_heads, dropout)
        self.film1 = FiLM(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout)
        self.film2 = FiLM(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
        self.film3 = FiLM(d_model)

    def forward(
        self,
        x: torch.Tensor,
        cross_cond: torch.Tensor,
        film_cond: torch.Tensor,
        target_mask: torch.Tensor = None,
        target_key_padding_mask: torch.Tensor = None,
        cross_cond_mask: torch.Tensor = None,
        cross_cond_key_padding_mask: torch.Tensor = None,
    ):
        """
        :param x: B x T x d_model tensor
        :param cross_cond: B x T x d_cond tensor containing the conditioning input
            to cross attention layers
        :param film_cond: B x [1 or T] x d_model tensor containing the conditioning
            input to FiLM layers
        :return: B x T x d_model tensor
        """
        x1 = self.self_attn(self.norm1(x), target_mask, target_key_padding_mask)
        x = x + self.film1(x1, film_cond)
        x2 = self.cross_attn(
            self.norm2(x), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
        )
        x = x + self.film2(x2, film_cond)
        x3 = self.feedforward(self.norm3(x))
        x = x + self.film3(x3, film_cond)
        return x


class RegressionTransformer(nn.Module):
    def __init__(
        self,
        transformer_encoder_layers: int = 2,
        transformer_decoder_layers: int = 4,
        d_model: int = 512,
        d_cond: int = 512,
        num_heads: int = 4,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
        causal: bool = False,
    ):
        super().__init__()
        self.causal = causal
        self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
        self.target_positional_encoding = PositionalEncoding(d_model, dropout)
        self.transformer_encoder = nn.ModuleList(
            [
                TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
                for _ in range(transformer_encoder_layers)
            ]
        )
        self.transformer_decoder = nn.ModuleList(
            [
                TransformerDecoderLayer(
                    d_model, d_cond, num_heads, d_feedforward, dropout
                )
                for _ in range(transformer_decoder_layers)
            ]
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """
        :param x: B x T x d_model input tensor
        :param cond: B x T x d_cond conditional tensor
        :return: B x T x d_model output tensor
        """
        x = self.target_positional_encoding(x)
        cond = self.cond_positional_encoding(cond)

        if self.causal:
            encoder_mask = generate_causal_mask(
                cond.shape[1], cond.shape[1], device=cond.device
            )
            decoder_self_attn_mask = generate_causal_mask(
                x.shape[1], x.shape[1], device=x.device
            )
            decoder_cross_attn_mask = generate_causal_mask(
                cond.shape[1], x.shape[1], device=x.device
            )
        else:
            encoder_mask = None
            decoder_self_attn_mask = None
            decoder_cross_attn_mask = None

        for encoder_layer in self.transformer_encoder:
            cond = encoder_layer(cond, mask=encoder_mask)
        for decoder_layer in self.transformer_decoder:
            x = decoder_layer(
                x,
                cond,
                target_mask=decoder_self_attn_mask,
                cross_cond_mask=decoder_cross_attn_mask,
            )
        return x
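
# Illustrative sketch (not part of the original API): a causal forward pass
# through the regression model. The sizes and the `_demo_regression_transformer`
# name are assumptions for illustration only.
def _demo_regression_transformer():
    batch, seq_len = 2, 16
    model = RegressionTransformer(d_model=64, d_cond=32, num_heads=4, causal=True)
    x = torch.randn(batch, seq_len, 64)     # target features
    cond = torch.randn(batch, seq_len, 32)  # conditioning features
    return model(x, cond)                   # B x T x d_model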
class DiffusionTransformer(nn.Module):
    def __init__(
        self,
        transformer_encoder_layers: int = 2,
        transformer_decoder_layers: int = 4,
        d_model: int = 512,
        d_cond: int = 512,
        num_heads: int = 4,
        d_feedforward: int = 1024,
        dropout: float = 0.1,
        causal: bool = False,
    ):
        super().__init__()
        self.causal = causal
        self.timestep_encoder = TimestepEncoding(d_model)
        self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
        self.target_positional_encoding = PositionalEncoding(d_model, dropout)
        self.transformer_encoder = nn.ModuleList(
            [
                TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
                for _ in range(transformer_encoder_layers)
            ]
        )
        self.transformer_decoder = nn.ModuleList(
            [
                FilmTransformerDecoderLayer(
                    d_model, d_cond, num_heads, d_feedforward, dropout
                )
                for _ in range(transformer_decoder_layers)
            ]
        )

    def forward(self, x: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
        """
        :param x: B x T x d_model input tensor
        :param cond: B x T x d_cond conditional tensor
        :param t: B-dimensional tensor containing diffusion timesteps in range [0, 1]
        :return: B x T x d_model output tensor
        """
        t = self.timestep_encoder(t).unsqueeze(1)  # B x 1 x d_model
        x = self.target_positional_encoding(x)
        cond = self.cond_positional_encoding(cond)

        if self.causal:
            encoder_mask = generate_causal_mask(
                cond.shape[1], cond.shape[1], device=cond.device
            )
            decoder_self_attn_mask = generate_causal_mask(
                x.shape[1], x.shape[1], device=x.device
            )
            decoder_cross_attn_mask = generate_causal_mask(
                cond.shape[1], x.shape[1], device=x.device
            )
        else:
            encoder_mask = None
            decoder_self_attn_mask = None
            decoder_cross_attn_mask = None

        for encoder_layer in self.transformer_encoder:
            cond = encoder_layer(cond, mask=encoder_mask)
        for decoder_layer in self.transformer_decoder:
            x = decoder_layer(
                x,
                cond,
                t,
                target_mask=decoder_self_attn_mask,
                cross_cond_mask=decoder_cross_attn_mask,
            )
        return x
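
# Illustrative sketch (not part of the original API): one denoising call of the
# FiLM-conditioned diffusion model. The sizes and the
# `_demo_diffusion_transformer` name are assumptions for illustration only.
def _demo_diffusion_transformer():
    batch, seq_len = 2, 16
    model = DiffusionTransformer(d_model=64, d_cond=32, num_heads=4)
    x = torch.randn(batch, seq_len, 64)     # noised target sequence
    cond = torch.randn(batch, seq_len, 32)  # conditioning sequence
    t = torch.rand(batch)                   # diffusion timesteps in [0, 1]
    return model(x, cond, t)                # B x T x d_model prediction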