# Modified from Flux
#
# Copyright 2024 Black Forest Labs
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import math  # noqa: I001
from dataclasses import dataclass
from functools import partial

import torch
import torch.nn.functional as F
from einops import rearrange

# from liger_kernel.ops.rms_norm import LigerRMSNormFunction
from torch import Tensor, nn

try:
    import flash_attn
    from flash_attn.flash_attn_interface import (
        _flash_attn_forward,
        flash_attn_varlen_func,
    )
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None

MEMORY_LAYOUT = {
    "flash": (
        lambda x: x.view(x.shape[0] * x.shape[1], *x.shape[2:]),
        lambda x: x,
    ),
    "torch": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="torch",
    drop_rate=0,
    attn_mask=None,
    causal=False,
    cu_seqlens_q=None,
    cu_seqlens_kv=None,
    max_seqlen_q=None,
    max_seqlen_kv=None,
    batch_size=1,
):
    """
    Perform QKV self-attention.

    Args:
        q (torch.Tensor): Query tensor with shape [b, s, a, d], where a is the number of heads.
        k (torch.Tensor): Key tensor with shape [b, s1, a, d].
        v (torch.Tensor): Value tensor with shape [b, s1, a, d].
        mode (str): Attention mode. Choose from 'flash', 'torch', and 'vanilla'.
        drop_rate (float): Dropout rate applied to the attention map. (default: 0)
        attn_mask (torch.Tensor): Attention mask with shape [b, s1] (cross attention),
            or [b, a, s, s1] ('torch' or 'vanilla' mode). (default: None)
        causal (bool): Whether to use causal attention. (default: False)
        cu_seqlens_q (torch.Tensor): dtype torch.int32. Cumulative sequence lengths of the
            sequences in the batch, used to index into q ('flash' mode only).
        cu_seqlens_kv (torch.Tensor): dtype torch.int32. Cumulative sequence lengths of the
            sequences in the batch, used to index into k and v ('flash' mode only).
        max_seqlen_q (int): The maximum sequence length of q in the batch.
        max_seqlen_kv (int): The maximum sequence length of k and v in the batch.
        batch_size (int): Batch size, used to reshape the 'flash' output back to
            [b, s, a, d]. (default: 1)

    Returns:
        torch.Tensor: Output tensor after self-attention, with shape [b, s, a * d].
    """
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]
    q = pre_attn_layout(q)
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
        )
    elif mode == "flash":
        assert flash_attn_varlen_func is not None
        x: torch.Tensor = flash_attn_varlen_func(
            q,
            k,
            v,
            cu_seqlens_q,
            cu_seqlens_kv,
            max_seqlen_q,
            max_seqlen_kv,
        )  # type: ignore
        # x has shape [(b * s), a, d]; reshape it back to [b, s, a, d]
        x = x.view(batch_size, max_seqlen_q, x.shape[-2], x.shape[-1])  # type: ignore
    elif mode == "vanilla":
        scale_factor = 1 / math.sqrt(q.size(-1))

        b, a, s, _ = q.shape
        s1 = k.size(2)
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)
        if causal:
            # Only applies to self-attention.
            assert attn_mask is None, (
                "Causal mask and attn_mask cannot be used together"
            )
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
                diagonal=0
            )
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        # TODO: Maybe force q and k to float32 to avoid numerical overflow.
        attn = (q @ k.transpose(-2, -1)) * scale_factor
        attn += attn_bias
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)
        x = attn @ v
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    x = post_attn_layout(x)
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)
    return out
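
# Usage sketch (illustrative only; shapes and values are assumptions, not taken from
# the original call sites). With mode="torch", q/k/v are [batch, seq, heads, head_dim]
# and the output is [batch, seq, heads * head_dim]:
#
#   q = torch.randn(2, 16, 8, 64)
#   k = torch.randn(2, 16, 8, 64)
#   v = torch.randn(2, 16, 8, 64)
#   out = attention(q, k, v, mode="torch")   # -> [2, 16, 512]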


def apply_gate(x, gate=None, tanh=False):
    """Apply a per-sample gate to a tensor.

    Args:
        x (torch.Tensor): Input tensor of shape [b, s, d].
        gate (torch.Tensor, optional): Gate tensor of shape [b, d], broadcast over the
            sequence dimension of x. Defaults to None (no gating).
        tanh (bool, optional): Whether to pass the gate through tanh first. Defaults to False.

    Returns:
        torch.Tensor: The gated output tensor.
    """
    if gate is None:
        return x
    if tanh:
        return x * gate.unsqueeze(1).tanh()
    else:
        return x * gate.unsqueeze(1)
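
# Usage sketch (assumed shapes): `gate` is per-sample and is broadcast over the
# sequence dimension of x:
#
#   x = torch.randn(2, 16, 128)
#   gate = torch.randn(2, 128)
#   y = apply_gate(x, gate)             # x * gate[:, None, :]
#   y = apply_gate(x, gate, tanh=True)  # x * tanh(gate)[:, None, :]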


class MLP(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks."""

    def __init__(
        self,
        in_channels,
        hidden_channels=None,
        out_features=None,
        act_layer=nn.GELU,
        norm_layer=None,
        bias=True,
        drop=0.0,
        use_conv=False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        out_features = out_features or in_channels
        hidden_channels = hidden_channels or in_channels
        bias = (bias, bias)
        drop_probs = (drop, drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear

        self.fc1 = linear_layer(
            in_channels, hidden_channels, bias=bias[0], device=device, dtype=dtype
        )
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = (
            norm_layer(hidden_channels, device=device, dtype=dtype)
            if norm_layer is not None
            else nn.Identity()
        )
        self.fc2 = linear_layer(
            hidden_channels, out_features, bias=bias[1], device=device, dtype=dtype
        )
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
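
# Usage sketch (hypothetical sizes): a two-layer MLP with GELU, of the kind used inside
# the transformer blocks below:
#
#   mlp = MLP(in_channels=1024, hidden_channels=4096)
#   y = mlp(torch.randn(2, 16, 1024))   # -> [2, 16, 1024]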


class TextProjection(nn.Module):
    """
    Projects text embeddings. Also handles dropout for classifier-free guidance.

    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
    """

    def __init__(self, in_channels, hidden_size, act_layer, dtype=None, device=None):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.linear_1 = nn.Linear(
            in_features=in_channels,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )
        self.act_1 = act_layer()
        self.linear_2 = nn.Linear(
            in_features=hidden_size,
            out_features=hidden_size,
            bias=True,
            **factory_kwargs,
        )

    def forward(self, caption):
        hidden_states = self.linear_1(caption)
        hidden_states = self.act_1(hidden_states)
        hidden_states = self.linear_2(hidden_states)
        return hidden_states


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(
        self,
        hidden_size,
        act_layer,
        frequency_embedding_size=256,
        max_period=10000,
        out_size=None,
        dtype=None,
        device=None,
    ):
        factory_kwargs = {"dtype": dtype, "device": device}
        super().__init__()
        self.frequency_embedding_size = frequency_embedding_size
        self.max_period = max_period
        if out_size is None:
            out_size = hidden_size

        self.mlp = nn.Sequential(
            nn.Linear(
                frequency_embedding_size, hidden_size, bias=True, **factory_kwargs
            ),
            act_layer(),
            nn.Linear(hidden_size, out_size, bias=True, **factory_kwargs),
        )
        nn.init.normal_(self.mlp[0].weight, std=0.02)  # type: ignore
        nn.init.normal_(self.mlp[2].weight, std=0.02)  # type: ignore

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.

        Args:
            t (torch.Tensor): a 1-D Tensor of N indices, one per batch element. These may be fractional.
            dim (int): the dimension of the output.
            max_period (int): controls the minimum frequency of the embeddings.

        Returns:
            embedding (torch.Tensor): An (N, D) Tensor of positional embeddings.

        .. ref_link: https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        t_freq = self.timestep_embedding(
            t, self.frequency_embedding_size, self.max_period
        ).type(self.mlp[0].weight.dtype)  # type: ignore
        t_emb = self.mlp(t_freq)
        return t_emb
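
# Usage sketch (hypothetical sizes): scalar timesteps are first mapped to sinusoidal
# features of size `frequency_embedding_size`, then projected by the MLP:
#
#   embedder = TimestepEmbedder(hidden_size=1024, act_layer=nn.SiLU)
#   t = torch.tensor([0.0, 250.0, 999.0])
#   t_emb = embedder(t)   # -> [3, 1024]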


class EmbedND(nn.Module):
    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def forward(self, ids: Tensor) -> Tensor:
        n_axes = ids.shape[-1]
        emb = torch.cat(
            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
            dim=-3,
        )
        return emb.unsqueeze(1)
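
# Usage sketch (assumed configuration): `axes_dim` splits each head dimension across
# the positional axes (e.g. text index, image row, image column), and `ids` carries one
# position per axis for every token:
#
#   embed = EmbedND(dim=128, theta=10_000, axes_dim=[16, 56, 56])  # sum == head_dim
#   ids = torch.zeros(2, 4096, 3)                                  # [batch, tokens, n_axes]
#   pe = embed(ids)   # -> [2, 1, 4096, 64, 2, 2], consumed by apply_rope below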


class MLPEmbedder(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        return self.out_layer(self.silu(self.in_layer(x)))


def rope(pos, dim: int, theta: int):
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack(
        [torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1
    )
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()
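
# Sketch of the rope() output layout (assumed inputs): for positions `pos` of shape
# [batch, n] and an even `dim`, the result packs a 2x2 rotation matrix per frequency:
#
#   pos = torch.arange(8, dtype=torch.float64)[None]   # [1, 8]
#   pe = rope(pos, dim=16, theta=10_000)                # -> [1, 8, 8, 2, 2], float32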


def attention_after_rope(q, k, v, pe):
    q, k = apply_rope(q, k, pe)
    # `attention` is defined above in this module, so no separate import is needed.
    x = attention(q, k, v, mode="torch")
    return x


def apply_rope(xq, xk, freqs_cis):
    # Swap the num_heads and seq_len dimensions back to the order used by the original implementation.
    xq = xq.transpose(1, 2)  # [batch, num_heads, seq_len, head_dim]
    xk = xk.transpose(1, 2)

    # Split head_dim into complex components (real and imaginary parts).
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)

    # Apply the rotary position embedding (complex multiplication).
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]

    # Restore the tensor shapes and transpose back to the target dimension order.
    xq_out = xq_out.reshape(*xq.shape).type_as(xq).transpose(1, 2)
    xk_out = xk_out.reshape(*xk.shape).type_as(xk).transpose(1, 2)
    return xq_out, xk_out
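
# Usage sketch (assumed shapes): q/k are [batch, seq, heads, head_dim] and `freqs_cis`
# is the [batch, 1, seq, head_dim // 2, 2, 2] tensor produced by EmbedND/rope above:
#
#   q = torch.randn(2, 4096, 8, 128)
#   k = torch.randn(2, 4096, 8, 128)
#   pe = EmbedND(dim=128, theta=10_000, axes_dim=[16, 56, 56])(torch.zeros(2, 4096, 3))
#   q_rot, k_rot = apply_rope(q, k, pe)   # same shapes as q and k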


def scale_add_residual(
    x: torch.Tensor, scale: torch.Tensor, residual: torch.Tensor
) -> torch.Tensor:
    return x * scale + residual


def layernorm_and_scale_shift(
    x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor
) -> torch.Tensor:
    return torch.nn.functional.layer_norm(x, (x.size(-1),)) * (scale + 1) + shift


class SelfAttention(nn.Module):
    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.norm = QKNorm(head_dim)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: Tensor, pe: Tensor) -> Tensor:
        qkv = self.qkv(x)
        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)
        x = attention_after_rope(q, k, v, pe=pe)
        x = self.proj(x)
        return x
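
# Usage sketch (hypothetical sizes): QK-normalised self-attention with rotary position
# embeddings supplied via `pe` (see EmbedND above):
#
#   attn = SelfAttention(dim=1024, num_heads=8)
#   x = torch.randn(2, 4096, 1024)
#   pe = EmbedND(dim=128, theta=10_000, axes_dim=[16, 56, 56])(torch.zeros(2, 4096, 3))
#   y = attn(x, pe)   # -> [2, 4096, 1024]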


@dataclass
class ModulationOut:
    shift: Tensor
    scale: Tensor
    gate: Tensor


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    # @staticmethod
    # def rms_norm_fast(x, weight, eps):
    #     return LigerRMSNormFunction.apply(
    #         x,
    #         weight,
    #         eps,
    #         0.0,
    #         "gemma",
    #         True,
    #     )

    @staticmethod
    def rms_norm(x, weight, eps):
        x_dtype = x.dtype
        x = x.float()
        rrms = torch.rsqrt(torch.mean(x**2, dim=-1, keepdim=True) + eps)
        return (x * rrms).to(dtype=x_dtype) * weight

    def forward(self, x: Tensor):
        # return self.rms_norm_fast(x, self.scale, 1e-6)
        return self.rms_norm(x, self.scale, 1e-6)


class QKNorm(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q.to(v), k.to(v)


class Modulation(nn.Module):
    def __init__(self, dim: int, double: bool):
        super().__init__()
        self.is_double = double
        self.multiplier = 6 if double else 3
        self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut | None]:
        out = self.lin(nn.functional.silu(vec))[:, None, :].chunk(
            self.multiplier, dim=-1
        )
        return (
            ModulationOut(*out[:3]),
            ModulationOut(*out[3:]) if self.is_double else None,
        )


class DoubleStreamBlock(nn.Module):
    def __init__(
        self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
    ):
        super().__init__()
        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_mod = Modulation(hidden_size, double=True)
        self.img_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )
        self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

        self.txt_mod = Modulation(hidden_size, double=True)
        self.txt_norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_attn = SelfAttention(
            dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
        )
        self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.txt_mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
            nn.GELU(approximate="tanh"),
            nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
        )

    def forward(
        self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor
    ) -> tuple[Tensor, Tensor]:
        img_mod1, img_mod2 = self.img_mod(vec)
        txt_mod1, txt_mod2 = self.txt_mod(vec)

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = rearrange(
            img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
        )
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = rearrange(
            txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
        )
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        q = torch.cat((txt_q, img_q), dim=1)
        k = torch.cat((txt_k, img_k), dim=1)
        v = torch.cat((txt_v, img_v), dim=1)

        attn = attention_after_rope(q, k, v, pe=pe)
        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img_mlp = self.img_mlp(
            (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
        )
        img = scale_add_residual(img_mlp, img_mod2.gate, img)

        # calculate the txt blocks
        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt_mlp = self.txt_mlp(
            (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
        )
        txt = scale_add_residual(txt_mlp, txt_mod2.gate, txt)
        return img, txt
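
# Usage sketch (hypothetical sizes): the double-stream block updates image and text
# tokens jointly; `vec` is the pooled conditioning vector used for modulation, and `pe`
# covers the concatenated txt+img sequence:
#
#   block = DoubleStreamBlock(hidden_size=1024, num_heads=8, mlp_ratio=4.0)
#   img = torch.randn(2, 4096, 1024)
#   txt = torch.randn(2, 256, 1024)
#   vec = torch.randn(2, 1024)
#   pe = EmbedND(dim=128, theta=10_000, axes_dim=[16, 56, 56])(torch.zeros(2, 4352, 3))
#   img, txt = block(img, txt, vec, pe)   # same shapes as the inputs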


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and an adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
        # proj and mlp_out
        self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)

        self.norm = QKNorm(head_dim)

        self.hidden_size = hidden_size
        self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.mlp_act = nn.GELU(approximate="tanh")
        self.modulation = Modulation(hidden_size, double=False)

    def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
        mod, _ = self.modulation(vec)
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(
            self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1
        )

        q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention_after_rope(q, k, v, pe=pe)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        return scale_add_residual(output, mod.gate, x)
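
# Usage sketch (hypothetical sizes): the single-stream block runs attention and the MLP
# in parallel over the already-concatenated txt+img sequence:
#
#   block = SingleStreamBlock(hidden_size=1024, num_heads=8, mlp_ratio=4.0)
#   x = torch.randn(2, 4352, 1024)
#   vec = torch.randn(2, 1024)
#   pe = EmbedND(dim=128, theta=10_000, axes_dim=[16, 56, 56])(torch.zeros(2, 4352, 3))
#   y = block(x, vec, pe)   # -> [2, 4352, 1024]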


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(
            hidden_size, patch_size * patch_size * out_channels, bias=True
        )
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        )

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
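
# Usage sketch (hypothetical sizes): the final layer applies adaLN modulation and
# projects each token to a flattened patch of output channels:
#
#   head = LastLayer(hidden_size=1024, patch_size=2, out_channels=16)
#   x = torch.randn(2, 4096, 1024)
#   vec = torch.randn(2, 1024)
#   out = head(x, vec)   # -> [2, 4096, 2 * 2 * 16]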