"""
Copyright (c) Meta Platforms, Inc. and affiliates.
All rights reserved.
This source code is licensed under the license found in the
LICENSE file in the root directory of this source tree.
"""
import math
from typing import Callable, Optional, Union
import torch
import torch.nn as nn
from einops import rearrange
from torch import Tensor
from torch.nn import functional as F
def generate_causal_mask(source_length, target_length, device="cpu"):
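    """Build an additive causal attention mask of shape (target_length, source_length).

    Allowed positions are 0.0 and blocked positions are -inf, so the result can be
    passed directly as ``attn_mask`` to ``nn.MultiheadAttention``. When the two
    lengths are equal this is the standard lower-triangular causal mask; otherwise
    each target position may attend to a proportional prefix of the source positions.
    """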
if source_length == target_length:
mask = (
torch.triu(torch.ones(target_length, source_length, device=device)) == 1
).transpose(0, 1)
else:
mask = torch.zeros(target_length, source_length, device=device)
idx = torch.linspace(0, source_length, target_length + 1)[1:].round().long()
for i in range(target_length):
mask[i, 0 : idx[i]] = 1
return (
mask.float()
.masked_fill(mask == 0, float("-inf"))
.masked_fill(mask == 1, float(0.0))
)
class TransformerEncoderLayerRotary(nn.Module):
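    """Transformer encoder layer with optional rotary position embeddings.

    Mirrors ``nn.TransformerEncoderLayer``; when a ``rotary`` module is supplied, its
    ``rotate_queries_or_keys`` method is applied to the queries and keys (but not the
    values) before self-attention.
    """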
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward: int = 2048,
dropout: float = 0.1,
activation: Union[str, Callable[[Tensor], Tensor]] = F.relu,
layer_norm_eps: float = 1e-5,
batch_first: bool = False,
norm_first: bool = True,
rotary=None,
) -> None:
super().__init__()
self.self_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first
)
# Implementation of Feedforward model
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm_first = norm_first
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.activation = activation
self.rotary = rotary
self.use_rotary = rotary is not None
def forward(
self,
src: Tensor,
src_mask: Optional[Tensor] = None,
src_key_padding_mask: Optional[Tensor] = None,
) -> Tensor:
x = src
if self.norm_first:
x = x + self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
x = x + self._ff_block(self.norm2(x))
else:
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
x = self.norm2(x + self._ff_block(x))
return x
# self-attention block
def _sa_block(
self, x: Tensor, attn_mask: Optional[Tensor], key_padding_mask: Optional[Tensor]
) -> Tensor:
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
x = self.self_attn(
qk,
qk,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout1(x)
# feed forward block
def _ff_block(self, x: Tensor) -> Tensor:
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout2(x)
class DenseFiLM(nn.Module):
"""Feature-wise linear modulation (FiLM) generator."""
def __init__(self, embed_channels):
super().__init__()
self.embed_channels = embed_channels
self.block = nn.Sequential(
nn.Mish(), nn.Linear(embed_channels, embed_channels * 2)
)
def forward(self, position):
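        """
        :param position: B x embed_channels conditioning tensor (e.g. a timestep embedding)
        :return: (scale, shift) tuple of B x 1 x embed_channels tensors for `featurewise_affine`
        """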
pos_encoding = self.block(position)
pos_encoding = rearrange(pos_encoding, "b c -> b 1 c")
scale_shift = pos_encoding.chunk(2, dim=-1)
return scale_shift
def featurewise_affine(x, scale_shift):
scale, shift = scale_shift
return (scale + 1) * x + shift
class FiLMTransformerDecoderLayer(nn.Module):
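    """Transformer decoder layer whose sub-block outputs are FiLM-modulated by `t`.

    Self-attention, cross-attention over ``memory`` (and, with ``use_cm=True``, a second
    cross-attention over ``memory2``), and the feedforward block each pass their output
    through a ``DenseFiLM`` conditioned on ``t`` before the residual connection.
    """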
def __init__(
self,
d_model: int,
nhead: int,
dim_feedforward=2048,
dropout=0.1,
activation=F.relu,
layer_norm_eps=1e-5,
batch_first=False,
norm_first=True,
rotary=None,
use_cm=False,
):
super().__init__()
self.self_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first
)
self.multihead_attn = nn.MultiheadAttention(
d_model, nhead, dropout=dropout, batch_first=batch_first
)
# Feedforward
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm_first = norm_first
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = activation
self.film1 = DenseFiLM(d_model)
self.film2 = DenseFiLM(d_model)
self.film3 = DenseFiLM(d_model)
        if use_cm:
            # Second cross-attention stream, used when `memory2` is passed to forward().
            self.multihead_attn2 = nn.MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=batch_first
            )
            self.norm2a = nn.LayerNorm(d_model, eps=layer_norm_eps)
            self.dropout2a = nn.Dropout(dropout)
            self.film2a = DenseFiLM(d_model)
self.rotary = rotary
self.use_rotary = rotary is not None
    # Arguments correspond to (x, cond, t) as passed by DecoderLayerStack.
def forward(
self,
tgt,
memory,
t,
tgt_mask=None,
memory_mask=None,
tgt_key_padding_mask=None,
memory_key_padding_mask=None,
memory2=None,
):
x = tgt
if self.norm_first:
# self-attention -> film -> residual
x_1 = self._sa_block(self.norm1(x), tgt_mask, tgt_key_padding_mask)
x = x + featurewise_affine(x_1, self.film1(t))
# cross-attention -> film -> residual
x_2 = self._mha_block(
self.norm2(x),
memory,
memory_mask,
memory_key_padding_mask,
self.multihead_attn,
self.dropout2,
)
x = x + featurewise_affine(x_2, self.film2(t))
if memory2 is not None:
# cross-attention x2 -> film -> residual
x_2a = self._mha_block(
self.norm2a(x),
memory2,
memory_mask,
memory_key_padding_mask,
self.multihead_attn2,
self.dropout2a,
)
x = x + featurewise_affine(x_2a, self.film2a(t))
# feedforward -> film -> residual
x_3 = self._ff_block(self.norm3(x))
x = x + featurewise_affine(x_3, self.film3(t))
else:
x = self.norm1(
x
+ featurewise_affine(
self._sa_block(x, tgt_mask, tgt_key_padding_mask), self.film1(t)
)
)
            x = self.norm2(
                x
                + featurewise_affine(
                    self._mha_block(
                        x,
                        memory,
                        memory_mask,
                        memory_key_padding_mask,
                        self.multihead_attn,
                        self.dropout2,
                    ),
                    self.film2(t),
                )
            )
x = self.norm3(x + featurewise_affine(self._ff_block(x), self.film3(t)))
return x
# self-attention block
# qkv
def _sa_block(self, x, attn_mask, key_padding_mask):
qk = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
x = self.self_attn(
qk,
qk,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return self.dropout1(x)
# multihead attention block
# qkv
def _mha_block(self, x, mem, attn_mask, key_padding_mask, mha, dropout):
q = self.rotary.rotate_queries_or_keys(x) if self.use_rotary else x
k = self.rotary.rotate_queries_or_keys(mem) if self.use_rotary else mem
x = mha(
q,
k,
mem,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
return dropout(x)
# feed forward block
def _ff_block(self, x):
x = self.linear2(self.dropout(self.activation(self.linear1(x))))
return self.dropout3(x)
class DecoderLayerStack(nn.Module):
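    """Applies a stack of decoder layers sequentially, with shared `cond`, `t`, mask, and optional `memory2`."""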
def __init__(self, stack):
super().__init__()
self.stack = stack
def forward(self, x, cond, t, tgt_mask=None, memory2=None):
for layer in self.stack:
x = layer(x, cond, t, tgt_mask=tgt_mask, memory2=memory2)
return x
class PositionalEncoding(nn.Module):
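    """Sinusoidal positional encoding added to the input sequence, followed by dropout."""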
def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 1024):
super().__init__()
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x: torch.Tensor):
"""
:param x: B x T x d_model tensor
:return: B x T x d_model tensor
"""
x = x + self.pe[None, : x.shape[1], :]
x = self.dropout(x)
return x
class TimestepEncoding(nn.Module):
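    """Sinusoidal (Fourier) embedding of a scalar timestep in [0, 1], refined by a small MLP."""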
def __init__(self, embedding_dim: int):
super().__init__()
# Fourier embedding
half_dim = embedding_dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim) * -emb)
self.register_buffer("emb", emb)
# encoding
self.encoding = nn.Sequential(
nn.Linear(embedding_dim, 4 * embedding_dim),
nn.Mish(),
nn.Linear(4 * embedding_dim, embedding_dim),
)
def forward(self, t: torch.Tensor):
"""
:param t: B-dimensional tensor containing timesteps in range [0, 1]
:return: B x embedding_dim tensor containing timestep encodings
"""
x = t[:, None] * self.emb[None, :]
x = torch.cat([torch.sin(x), torch.cos(x)], dim=-1)
x = self.encoding(x)
return x
class FiLM(nn.Module):
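    """Feature-wise linear modulation: predicts a per-feature scale and bias from `cond` and applies them to `x`."""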
def __init__(self, dim: int):
super().__init__()
self.dim = dim
self.film = nn.Sequential(nn.Mish(), nn.Linear(dim, dim * 2))
def forward(self, x: torch.Tensor, cond: torch.Tensor):
"""
:param x: ... x dim tensor
:param cond: ... x dim tensor
:return: ... x dim tensor as scale(cond) * x + bias(cond)
"""
cond = self.film(cond)
scale, bias = torch.chunk(cond, chunks=2, dim=-1)
x = (scale + 1) * x + bias
return x
class FeedforwardBlock(nn.Module):
def __init__(self, d_model: int, d_feedforward: int = 1024, dropout: float = 0.1):
super().__init__()
self.ff = nn.Sequential(
nn.Linear(d_model, d_feedforward),
nn.ReLU(),
nn.Dropout(p=dropout),
nn.Linear(d_feedforward, d_model),
nn.Dropout(p=dropout),
)
def forward(self, x: torch.Tensor):
"""
:param x: ... x d_model tensor
:return: ... x d_model tensor
"""
return self.ff(x)
class SelfAttention(nn.Module):
def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
super().__init__()
self.self_attn = nn.MultiheadAttention(
d_model, num_heads, dropout=dropout, batch_first=True
)
self.dropout = nn.Dropout(p=dropout)
def forward(
self,
x: torch.Tensor,
attn_mask: torch.Tensor = None,
key_padding_mask: torch.Tensor = None,
):
"""
:param x: B x T x d_model input tensor
:param attn_mask: B * num_heads x L x S mask with L=target sequence length, S=source sequence length
for a float mask: values will be added to attention weight
for a binary mask: True indicates that the element is not allowed to attend
:param key_padding_mask: B x S mask
for a float mask: values will be added directly to the corresponding key values
for a binary mask: True indicates that the corresponding key value will be ignored
:return: B x T x d_model output tensor
"""
x = self.self_attn(
x,
x,
x,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
x = self.dropout(x)
return x
class CrossAttention(nn.Module):
def __init__(self, d_model: int, d_cond: int, num_heads: int, dropout: float = 0.1):
super().__init__()
self.cross_attn = nn.MultiheadAttention(
d_model,
num_heads,
dropout=dropout,
batch_first=True,
kdim=d_cond,
vdim=d_cond,
)
self.dropout = nn.Dropout(p=dropout)
def forward(
self,
x: torch.Tensor,
cond: torch.Tensor,
attn_mask: torch.Tensor = None,
key_padding_mask: torch.Tensor = None,
):
"""
:param x: B x T_target x d_model input tensor
:param cond: B x T_cond x d_cond condition tensor
:param attn_mask: B * num_heads x L x S mask with L=target sequence length, S=source sequence length
for a float mask: values will be added to attention weight
for a binary mask: True indicates that the element is not allowed to attend
:param key_padding_mask: B x S mask
for a float mask: values will be added directly to the corresponding key values
for a binary mask: True indicates that the corresponding key value will be ignored
        :return: B x T_target x d_model output tensor
"""
x = self.cross_attn(
x,
cond,
cond,
attn_mask=attn_mask,
key_padding_mask=key_padding_mask,
need_weights=False,
)[0]
x = self.dropout(x)
return x
class TransformerEncoderLayer(nn.Module):
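    """Pre-norm, batch-first Transformer encoder layer (self-attention + feedforward)."""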
def __init__(
self,
d_model: int,
num_heads: int,
d_feedforward: int = 1024,
dropout: float = 0.1,
):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.self_attn = SelfAttention(d_model, num_heads, dropout)
self.norm2 = nn.LayerNorm(d_model)
self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor = None,
key_padding_mask: torch.Tensor = None,
):
x = x + self.self_attn(self.norm1(x), mask, key_padding_mask)
x = x + self.feedforward(self.norm2(x))
return x
class TransformerDecoderLayer(nn.Module):
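    """Pre-norm decoder layer: self-attention, cross-attention over a conditioning sequence, then feedforward."""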
def __init__(
self,
d_model: int,
d_cond: int,
num_heads: int,
d_feedforward: int = 1024,
dropout: float = 0.1,
):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.self_attn = SelfAttention(d_model, num_heads, dropout)
self.norm2 = nn.LayerNorm(d_model)
self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout)
self.norm3 = nn.LayerNorm(d_model)
self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
def forward(
self,
x: torch.Tensor,
cross_cond: torch.Tensor,
target_mask: torch.Tensor = None,
target_key_padding_mask: torch.Tensor = None,
cross_cond_mask: torch.Tensor = None,
cross_cond_key_padding_mask: torch.Tensor = None,
):
"""
:param x: B x T x d_model tensor
:param cross_cond: B x T x d_cond tensor containing the conditioning input to cross attention layers
:return: B x T x d_model tensor
"""
x = x + self.self_attn(self.norm1(x), target_mask, target_key_padding_mask)
x = x + self.cross_attn(
self.norm2(x), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
)
x = x + self.feedforward(self.norm3(x))
return x
class FilmTransformerDecoderLayer(nn.Module):
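    """`TransformerDecoderLayer` variant in which each sub-block output is FiLM-conditioned before the residual."""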
def __init__(
self,
d_model: int,
d_cond: int,
num_heads: int,
d_feedforward: int = 1024,
dropout: float = 0.1,
):
super().__init__()
self.norm1 = nn.LayerNorm(d_model)
self.self_attn = SelfAttention(d_model, num_heads, dropout)
self.film1 = FiLM(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.cross_attn = CrossAttention(d_model, d_cond, num_heads, dropout)
self.film2 = FiLM(d_model)
self.norm3 = nn.LayerNorm(d_model)
self.feedforward = FeedforwardBlock(d_model, d_feedforward, dropout)
self.film3 = FiLM(d_model)
def forward(
self,
x: torch.Tensor,
cross_cond: torch.Tensor,
film_cond: torch.Tensor,
target_mask: torch.Tensor = None,
target_key_padding_mask: torch.Tensor = None,
cross_cond_mask: torch.Tensor = None,
cross_cond_key_padding_mask: torch.Tensor = None,
):
"""
:param x: B x T x d_model tensor
:param cross_cond: B x T x d_cond tensor containing the conditioning input to cross attention layers
        :param film_cond: B x [1 or T] x d_model tensor containing the conditioning input to FiLM layers
:return: B x T x d_model tensor
"""
x1 = self.self_attn(self.norm1(x), target_mask, target_key_padding_mask)
x = x + self.film1(x1, film_cond)
x2 = self.cross_attn(
self.norm2(x), cross_cond, cross_cond_mask, cross_cond_key_padding_mask
)
x = x + self.film2(x2, film_cond)
x3 = self.feedforward(self.norm3(x))
x = x + self.film3(x3, film_cond)
return x
class RegressionTransformer(nn.Module):
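    """Encoder-decoder Transformer that regresses a target sequence from a conditioning sequence.

    The conditioning sequence is refined by the encoder stack and then attended to by the
    decoder stack. With ``causal=True``, causal masks are applied to the encoder
    self-attention, the decoder self-attention, and the decoder cross-attention.
    """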
def __init__(
self,
transformer_encoder_layers: int = 2,
transformer_decoder_layers: int = 4,
d_model: int = 512,
d_cond: int = 512,
num_heads: int = 4,
d_feedforward: int = 1024,
dropout: float = 0.1,
causal: bool = False,
):
super().__init__()
self.causal = causal
self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
self.target_positional_encoding = PositionalEncoding(d_model, dropout)
self.transformer_encoder = nn.ModuleList(
[
TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
for _ in range(transformer_encoder_layers)
]
)
self.transformer_decoder = nn.ModuleList(
[
TransformerDecoderLayer(
d_model, d_cond, num_heads, d_feedforward, dropout
)
for _ in range(transformer_decoder_layers)
]
)
def forward(self, x: torch.Tensor, cond: torch.Tensor):
"""
:param x: B x T x d_model input tensor
:param cond: B x T x d_cond conditional tensor
:return: B x T x d_model output tensor
"""
x = self.target_positional_encoding(x)
cond = self.cond_positional_encoding(cond)
if self.causal:
encoder_mask = generate_causal_mask(
cond.shape[1], cond.shape[1], device=cond.device
)
decoder_self_attn_mask = generate_causal_mask(
x.shape[1], x.shape[1], device=x.device
)
decoder_cross_attn_mask = generate_causal_mask(
cond.shape[1], x.shape[1], device=x.device
)
else:
encoder_mask = None
decoder_self_attn_mask = None
decoder_cross_attn_mask = None
for encoder_layer in self.transformer_encoder:
cond = encoder_layer(cond, mask=encoder_mask)
for decoder_layer in self.transformer_decoder:
x = decoder_layer(
x,
cond,
target_mask=decoder_self_attn_mask,
cross_cond_mask=decoder_cross_attn_mask,
)
return x
class DiffusionTransformer(nn.Module):
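    """`RegressionTransformer` variant whose decoder layers are FiLM-conditioned on a diffusion timestep embedding."""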
def __init__(
self,
transformer_encoder_layers: int = 2,
transformer_decoder_layers: int = 4,
d_model: int = 512,
d_cond: int = 512,
num_heads: int = 4,
d_feedforward: int = 1024,
dropout: float = 0.1,
causal: bool = False,
):
super().__init__()
self.causal = causal
self.timestep_encoder = TimestepEncoding(d_model)
self.cond_positional_encoding = PositionalEncoding(d_cond, dropout)
self.target_positional_encoding = PositionalEncoding(d_model, dropout)
self.transformer_encoder = nn.ModuleList(
[
TransformerEncoderLayer(d_cond, num_heads, d_feedforward, dropout)
for _ in range(transformer_encoder_layers)
]
)
self.transformer_decoder = nn.ModuleList(
[
FilmTransformerDecoderLayer(
d_model, d_cond, num_heads, d_feedforward, dropout
)
for _ in range(transformer_decoder_layers)
]
)
def forward(self, x: torch.Tensor, cond: torch.Tensor, t: torch.Tensor):
"""
:param x: B x T x d_model input tensor
:param cond: B x T x d_cond conditional tensor
:param t: B-dimensional tensor containing diffusion timesteps in range [0, 1]
:return: B x T x d_model output tensor
"""
t = self.timestep_encoder(t).unsqueeze(1) # B x 1 x d_model
x = self.target_positional_encoding(x)
cond = self.cond_positional_encoding(cond)
if self.causal:
encoder_mask = generate_causal_mask(
cond.shape[1], cond.shape[1], device=cond.device
)
decoder_self_attn_mask = generate_causal_mask(
x.shape[1], x.shape[1], device=x.device
)
decoder_cross_attn_mask = generate_causal_mask(
cond.shape[1], x.shape[1], device=x.device
)
else:
encoder_mask = None
decoder_self_attn_mask = None
decoder_cross_attn_mask = None
for encoder_layer in self.transformer_encoder:
cond = encoder_layer(cond, mask=encoder_mask)
for decoder_layer in self.transformer_decoder:
x = decoder_layer(
x,
cond,
t,
target_mask=decoder_self_attn_mask,
cross_cond_mask=decoder_cross_attn_mask,
)
return x
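

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only): the batch size, sequence
    # length, and feature dimensions below are assumptions, not values taken
    # from elsewhere in this repository.
    model = DiffusionTransformer(d_model=512, d_cond=512, causal=True)
    x = torch.randn(2, 64, 512)     # noisy target sequence, B x T x d_model
    cond = torch.randn(2, 64, 512)  # conditioning sequence, B x T x d_cond
    t = torch.rand(2)               # diffusion timesteps in [0, 1]
    out = model(x, cond, t)
    print(out.shape)                # torch.Size([2, 64, 512])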