Spaces:
Running
on
Zero
Running
on
Zero
import torch | |
import torch.nn as nn | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func | |
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Insert a singleton axis into ``freqs_cis`` so it broadcasts against ``x``.

    ``x`` must be 4-dimensional and ``freqs_cis`` must have the shape
    ``(x.shape[0], x.shape[1], x.shape[-1])``. The result is a view of
    ``freqs_cis`` with shape ``(x.shape[0], x.shape[1], 1, x.shape[-1])``, so a
    pointwise multiply with ``x`` broadcasts over axis 2.

    Raises:
        AssertionError: if ``x`` is not 4-D or the shapes do not line up.
    """
    # x is laid out as (bsz, seqlen, n_local_heads, head_hidden_dim / 2); the
    # last axis is halved because the caller has already converted to complex.
    assert x.ndim == 4
    bsz, seqlen, half_dim = x.shape[0], x.shape[1], x.shape[-1]
    assert freqs_cis.shape == (bsz, seqlen, half_dim), \
        f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}'
    # Size-1 axis at position 2 lets the same frequencies apply to every head.
    return freqs_cis.view(bsz, seqlen, 1, half_dim)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Rotate queries and keys by multiplying them with ``freqs_cis``.

    The last axis of ``xq`` / ``xk`` is regrouped into pairs and viewed as
    complex numbers, multiplied pointwise by ``freqs_cis`` (broadcast via
    ``reshape_for_broadcast`` using the query's complex shape), then converted
    back to real values and flattened from axis 3 onward.

    Returns:
        ``(xq_out, xk_out)`` cast back to the dtypes of ``xq`` and ``xk``.
    """
    def as_complex(t: torch.Tensor) -> torch.Tensor:
        # The arithmetic is done in float32; pairs of adjacent features become
        # the real/imaginary parts of one complex number.
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    q_complex = as_complex(xq)
    k_complex = as_complex(xk)
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    q_rotated = torch.view_as_real(q_complex * rotation).flatten(3)
    k_rotated = torch.view_as_real(k_complex * rotation).flatten(3)
    return q_rotated.type_as(xq), k_rotated.type_as(xk)
class Attention(nn.Module):
    """Multi-head self-attention with optional per-head Q/K normalisation.

    ``forward`` dispatches to one of three execution paths:

    * flash attention on an un-padded batch (``flash_attention=True`` and
      ``nopadding=True``); this path applies ``rotary_pos_emb`` and supports a
      key/value cache through ``incremental_state``;
    * flash attention on a padded batch (``flash_attention=True`` and
      ``nopadding=False``); padding is stripped with ``seq_len`` before the
      kernel runs and the output is zero-padded again afterwards;
    * a plain PyTorch implementation (``flash_attention=False``).

    Args:
        dim: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused QKV projection has a bias.
        qk_norm: Apply ``norm_layer(head_dim)`` to queries and keys.
        attn_drop: Dropout probability on the attention weights.
        proj_drop: Dropout probability after the output projection.
        norm_layer: Normalisation class used when ``qk_norm`` is set.
        flash_attention: Use the flash-attention kernels instead of the
            PyTorch fallback.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        qk_norm: bool = False,
        attn_drop: float = 0.,
        proj_drop: float = 0.,
        norm_layer: nn.Module = nn.LayerNorm,
        flash_attention: bool = True
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = flash_attention

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qk_norm = qk_norm
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb=None, incremental_state=None, nopadding=True) -> torch.Tensor:
        """Run self-attention over ``x`` of shape ``(B, N, C)``.

        Args:
            x: Input of shape ``(B, N, C)``.
            seq_len: Per-sample valid lengths, indexable by batch position.
                Used only when ``nopadding`` is False.
            cu_seqlens, max_seqlen: Forwarded to the flash-attention kernels
                for the queries; ``cu_seqlens`` is also used to slice the
                packed output in padded mode.
            cu_seqlens_k, max_seqlen_k: Forwarded to the flash-attention
                kernel for the keys on the un-padded path.
            rotary_pos_emb: Optional rotation passed to ``apply_rotary_emb``.
            incremental_state: Optional dict used as a key/value cache. Keys
                ``"prev_k"`` / ``"prev_v"`` are read and ``"cur_k"`` /
                ``"cur_v"`` are written. Only the un-padded flash path
                supports it.
            nopadding: Whether the batch contains no padding.

        Returns:
            Tensor of shape ``(B, N, C)``.

        Raises:
            NotImplementedError: if ``incremental_state`` is given on a path
                that does not support caching.
        """
        B, N, C = x.shape
        if not self.fused_attn:
            x = self._dense_attention(x, seq_len, rotary_pos_emb, incremental_state, nopadding)
        elif nopadding:
            x = self._flash_unpadded(
                x, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k,
                rotary_pos_emb, incremental_state,
            )
        else:
            x = self._flash_padded(x, seq_len, cu_seqlens, max_seqlen, incremental_state)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def _flash_unpadded(self, x, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb, incremental_state):
        """Flash attention on an un-padded batch, with rotary and KV cache."""
        B, N, _ = x.shape
        qkv = self.qkv(x)
        qkv = qkv.view(B * N, self.num_heads * 3, self.head_dim)
        q, k, v = qkv.split([self.num_heads] * 3, dim=1)
        q, k = self.q_norm(q), self.k_norm(k)

        q = q.view(B, N, self.num_heads, self.head_dim)
        k = k.view(B, N, self.num_heads, self.head_dim)
        v = v.view(B, N, self.num_heads, self.head_dim)
        if rotary_pos_emb is not None:
            q, k = apply_rotary_emb(q, k, rotary_pos_emb)

        if incremental_state is not None:
            # Prepend cached keys/values along the sequence axis and publish
            # the extended tensors for the caller.
            if "prev_k" in incremental_state:
                k = torch.cat([incremental_state["prev_k"], k], dim=1)
            incremental_state["cur_k"] = k
            if "prev_v" in incremental_state:
                v = torch.cat([incremental_state["prev_v"], v], dim=1)
            incremental_state["cur_v"] = v

        # The varlen kernel takes tokens flattened over batch and sequence.
        q = q.view(B * N, self.num_heads, self.head_dim)
        k = k.view(-1, self.num_heads, self.head_dim)
        v = v.view(-1, self.num_heads, self.head_dim)
        return flash_attn_varlen_func(
            q=q,
            k=k,
            v=v,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen,
            max_seqlen_k=max_seqlen_k,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )

    def _flash_padded(self, x, seq_len, cu_seqlens, max_seqlen, incremental_state):
        """Flash attention on a padded batch: strip padding, attend, re-pad."""
        if incremental_state is not None:
            raise NotImplementedError("It is designed for batching inference. AR-chunk is not supported currently.")
        B, N, _ = x.shape
        # NOTE(review): rotary_pos_emb is not applied on this path (unchanged
        # from the original code) — confirm this is intended.
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        if self.qk_norm:
            q, k, v = qkv.unbind(2)
            q, k = self.q_norm(q), self.k_norm(k)
            # re-bind
            qkv = torch.stack((q, k, v), dim=2)

        # Pack: keep only the first seq_len[i] tokens of every sample.
        qkv = torch.cat([qkv[i, :seq_len[i]] for i in range(qkv.shape[0])], dim=0)

        x = flash_attn_varlen_qkvpacked_func(
            qkv=qkv,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )

        # Unpack: slice each sample back out and zero-pad to a common length.
        per_sample = [x[cu_seqlens[i]:cu_seqlens[i + 1]] for i in range(B)]
        return torch.nn.utils.rnn.pad_sequence(per_sample, batch_first=True, padding_value=0)

    def _dense_attention(self, x, seq_len, rotary_pos_emb, incremental_state, nopadding):
        """Plain PyTorch attention used when flash attention is disabled.

        Every batch row attends densely over its own ``N`` tokens. In padded
        mode, keys at positions ``>= seq_len[i]`` are masked out. Outputs at
        padded query positions are not zeroed here.

        NOTE(review): ``cu_seqlens`` is ignored on this path, so several
        sequences packed into one batch row are not separated from each other.
        """
        if incremental_state is not None:
            raise NotImplementedError("KV caching is only supported on the flash-attention no-padding path.")
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)  # each (B, N, heads, head_dim)
        q, k = self.q_norm(q), self.k_norm(k)
        if rotary_pos_emb is not None:
            q, k = apply_rotary_emb(q, k, rotary_pos_emb)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # (B, heads, N, head_dim)

        attn = (q * self.scale) @ k.transpose(-2, -1)
        if not nopadding and seq_len is not None:
            lengths = torch.as_tensor(seq_len, device=x.device)
            key_is_valid = torch.arange(N, device=x.device)[None, :] < lengths[:, None]
            # A large finite negative (rather than -inf) keeps the softmax
            # free of NaNs even when a row has no valid key at all.
            attn = attn.masked_fill(~key_is_valid[:, None, None, :], torch.finfo(attn.dtype).min)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        # Back to (B, N, heads, head_dim) so the caller can merge the heads.
        return (attn @ v).transpose(1, 2)
def modulate(x, shift, scale):
    """Scale ``x`` by ``1 + scale`` and then add ``shift``."""
    scaled = x * (1 + scale)
    return scaled + shift
class FinalLayer(nn.Module):
    """
    The final layer of DiT: a parameter-free LayerNorm, a shift/scale
    modulation predicted from the conditioning input, and a linear projection
    from ``hidden_size`` to ``out_channels``.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, eps=1e-6, elementwise_affine=False)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        # One projection yields both shift and scale, hence twice the width.
        modulation_proj = nn.Linear(hidden_size, 2 * hidden_size, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), modulation_proj)

    def forward(self, x, c):
        """Modulate the normalised ``x`` with ``c`` and project it.

        ``c`` is split in two along axis 2, so it must have at least three
        dimensions.
        """
        modulation = self.adaLN_modulation(c)
        shift, scale = modulation.chunk(2, dim=2)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
class DiTBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning:
    a self-attention sub-layer followed by a feed-forward sub-layer, each
    wrapped in a shift/scale modulation and a gated residual connection.

    Only ``ffn_type="vanilla_mlp"`` is implemented; any other value (including
    the default) raises ``NotImplementedError``. ``ffn_gated_glu``,
    ``ffn_act_layer`` and ``ffn_conv_kernel_size`` are accepted but not used
    by this class. Extra keyword arguments are forwarded to ``Attention``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="conv1d_conv1d", ffn_gated_glu=True, ffn_act_layer="gelu", ffn_conv_kernel_size=5, **block_kwargs):
        super().__init__()

        def plain_norm():
            # LayerNorm without learnable affine parameters; scale and shift
            # come from the conditioning signal instead.
            return nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        self.norm1 = plain_norm()
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = plain_norm()

        if ffn_type != "vanilla_mlp":
            raise NotImplementedError(f"FFN type {ffn_type} is not implemented")
        from timm.models.vision_transformer import Mlp
        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=lambda: nn.GELU(approximate="tanh"),
            drop=0,
        )

        # Six chunks: shift/scale/gate for attention, then for the MLP.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, rotary_pos_emb=None, incremental_state=None, nopadding=True):
        """Apply the attention and feed-forward sub-layers to ``x``.

        ``c`` is projected and split into six pieces along axis 2. When
        ``incremental_state`` is given, its ``"attn_kvcache"`` entry (created
        if missing) is handed to the attention layer. When ``nopadding`` is
        False, both sub-layer outputs are multiplied by ``mask[:, :, None]``
        before being added to the residual stream.
        """
        modulation = self.adaLN_modulation(c)
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = modulation.chunk(6, dim=2)

        # --- attention sub-layer ---
        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        attn_cache = None
        if incremental_state is not None:
            attn_cache = incremental_state.setdefault("attn_kvcache", {})
        attn_out = self.attn(
            attn_in,
            seq_len=seq_len,
            cu_seqlens=cu_seqlens,
            max_seqlen=cu_maxlen,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_k=cu_maxlen_k,
            rotary_pos_emb=rotary_pos_emb,
            incremental_state=attn_cache,
            nopadding=nopadding,
        )
        if not nopadding:
            attn_out = attn_out * mask[:, :, None]
        x = x + gate_msa * attn_out

        # --- feed-forward sub-layer ---
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        if not nopadding:
            mlp_out = mlp_out * mask[:, :, None]
        return x + gate_mlp * mlp_out