jzq11111's picture
Upload folder using huggingface_hub
a3e05e8 verified
raw
history blame contribute delete
9.13 kB
import torch
import torch.nn as nn
import torch
import torch.nn as nn
import torch.nn.functional as F
from flash_attn import flash_attn_varlen_func, flash_attn_varlen_qkvpacked_func
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Insert a singleton head axis into ``freqs_cis`` so it broadcasts over ``x``.

    Args:
        freqs_cis: Tensor of shape ``(x.shape[0], x.shape[1], x.shape[-1])``.
        x: 4-D tensor. Per the caller's convention this is
            ``(bsz, seqlen, n_heads, head_dim / 2)``, the last axis being halved
            because feature pairs were folded into complex numbers.

    Returns:
        A view of ``freqs_cis`` with shape
        ``(x.shape[0], x.shape[1], 1, x.shape[-1])``.

    Raises:
        AssertionError: If ``x`` is not 4-D or the shapes do not line up.
    """
    assert x.ndim == 4
    bsz, seqlen, half_dim = x.shape[0], x.shape[1], x.shape[-1]
    assert freqs_cis.shape == (bsz, seqlen, half_dim), \
        f'x shape: {x.shape}, freqs_cis shape: {freqs_cis.shape}'
    # The extra axis of size 1 sits where the heads live, so one table is
    # shared by every head during the pointwise multiply.
    return freqs_cis.view(bsz, seqlen, 1, half_dim)
def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
):
    """Rotate query and key features by the complex factors in ``freqs_cis``.

    Adjacent pairs along the last axis of ``xq`` / ``xk`` are treated as the
    real and imaginary parts of a complex number, multiplied by ``freqs_cis``
    (broadcast over the head axis), and unpacked back into real pairs. The
    arithmetic runs in float32; results are cast back to each input's dtype.

    Args:
        xq: 4-D query tensor whose last dimension is even.
        xk: 4-D key tensor whose last dimension is even.
        freqs_cis: Tensor of shape
            ``(xq.shape[0], xq.shape[1], xq.shape[-1] // 2)``.

    Returns:
        Tuple ``(xq_out, xk_out)`` with the same shapes and dtypes as the inputs.
    """

    def _as_complex(t: torch.Tensor) -> torch.Tensor:
        # (..., d) -> (..., d / 2, 2) -> complex (..., d / 2)
        pairs = t.float().reshape(*t.shape[:-1], -1, 2)
        return torch.view_as_complex(pairs)

    def _rotate(t_complex: torch.Tensor, like: torch.Tensor) -> torch.Tensor:
        # Complex multiply, then merge the (d / 2, 2) tail back into d.
        rotated = torch.view_as_real(t_complex * rotation).flatten(3)
        return rotated.type_as(like)

    q_complex = _as_complex(xq)
    k_complex = _as_complex(xk)
    # The table is shaped against the query tensor only.
    rotation = reshape_for_broadcast(freqs_cis, q_complex)
    return _rotate(q_complex, xq), _rotate(k_complex, xk)
class Attention(nn.Module):
    """Multi-head self-attention with optional QK-norm, rotary embedding and KV cache.

    The execution path is chosen by ``flash_attention`` (constructor) and
    ``nopadding`` (``forward``):

    * fused, ``nopadding=True``: ``flash_attn_varlen_func`` over the flattened
      ``B * N`` tokens; rotary embedding and the KV cache are supported.
    * fused, ``nopadding=False``: each row is trimmed to ``seq_len[i]``, packed,
      run through ``flash_attn_varlen_qkvpacked_func`` and zero-padded back.
    * non-fused: plain softmax attention in PyTorch that mirrors whichever of
      the two modes above ``nopadding`` selects.

    The fused QKV projection is laid out as ``(3, num_heads, head_dim)`` along
    the feature axis on every path, so weights are interchangeable between them.
    """

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
            flash_attention: bool = True
    ) -> None:
        """
        Args:
            dim: Model width; must be divisible by ``num_heads``.
            num_heads: Number of attention heads.
            qkv_bias: Whether the fused QKV projection has a bias.
            qk_norm: If true, apply ``norm_layer(head_dim)`` to q and k per head.
            attn_drop: Dropout probability on attention weights (training only).
            proj_drop: Dropout probability after the output projection.
            norm_layer: Factory for the q/k normalisation layers.
            flash_attention: Use the flash-attn kernels instead of the dense
                PyTorch fallback.
        """
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = flash_attention

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.qk_norm = qk_norm
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def _project_qkv(self, x: torch.Tensor):
        """Project ``x`` (B, N, C) to q, k, v of shape (B, N, num_heads, head_dim); q and k are normed."""
        B, N, _ = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim)
        q, k, v = qkv.unbind(2)
        return self.q_norm(q), self.k_norm(k), v

    @staticmethod
    def _extend_kv_cache(incremental_state, k: torch.Tensor, v: torch.Tensor):
        """Prepend cached ``prev_k`` / ``prev_v`` (if present) and publish the results as ``cur_k`` / ``cur_v``."""
        if "prev_k" in incremental_state:
            k = torch.cat([incremental_state["prev_k"], k], dim=1)
        incremental_state["cur_k"] = k
        if "prev_v" in incremental_state:
            v = torch.cat([incremental_state["prev_v"], v], dim=1)
        incremental_state["cur_v"] = v
        return k, v

    def _dense_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, seq_len=None) -> torch.Tensor:
        """Softmax attention in plain PyTorch on (B, N, H, D) inputs; returns (B, N_q, H, D).

        When ``seq_len`` is given, keys at positions ``>= seq_len[i]`` are
        masked out and outputs at those query positions are zeroed, matching
        the zero padding produced by the fused padded path.
        """
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # -> (B, H, N, D)
        attn = (q * self.scale) @ k.transpose(-2, -1)  # (B, H, N_q, N_k)
        valid = None
        if seq_len is not None:
            lengths = torch.as_tensor(seq_len, device=attn.device)
            positions = torch.arange(k.shape[-2], device=attn.device)
            valid = positions[None, :] < lengths[:, None]  # (B, N_k)
            # Finite minimum instead of -inf: a fully masked row then gives a
            # uniform softmax rather than NaN (the row is zeroed below anyway).
            attn = attn.masked_fill(~valid[:, None, None, :], torch.finfo(attn.dtype).min)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        out = (attn @ v).transpose(1, 2)  # (B, N_q, H, D)
        if valid is not None:
            out = out.masked_fill(~valid[:, :, None, None], 0.)
        return out

    def _padded_flash_attention(self, q, k, v, seq_len, cu_seqlens, max_seqlen, padded_len) -> torch.Tensor:
        """Trim padding, run packed varlen flash attention, and zero-pad back to ``padded_len`` tokens."""
        qkv = torch.stack((q, k, v), dim=2)  # (B, N, 3, H, D)
        batch = qkv.shape[0]
        # Pack only the valid prefix of every row into one token-major tensor.
        packed = torch.cat([qkv[i, :seq_len[i]] for i in range(batch)], dim=0)
        out = flash_attn_varlen_qkvpacked_func(
            qkv=packed,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
        # Unpack per sequence and pad with zeros.
        per_seq = [out[cu_seqlens[i]:cu_seqlens[i + 1]] for i in range(batch)]
        out = torch.nn.utils.rnn.pad_sequence(per_seq, batch_first=True, padding_value=0)
        # pad_sequence only pads to the longest sequence in the batch; extend to
        # the input's padded length so the caller's reshape(B, N, C) is valid.
        if out.shape[1] < padded_len:
            out = F.pad(out, (0, 0, 0, 0, 0, padded_len - out.shape[1]))
        return out

    def forward(self, x: torch.Tensor, seq_len, cu_seqlens, max_seqlen, cu_seqlens_k, max_seqlen_k, rotary_pos_emb=None, incremental_state=None, nopadding=True) -> torch.Tensor:
        """Apply self-attention followed by the output projection.

        Args:
            x: Input of shape ``(B, N, C)`` with ``C == dim``.
            seq_len: Per-row valid lengths, indexable by batch index; only read
                when ``nopadding`` is false.
            cu_seqlens: Cumulative query lengths forwarded to flash-attn.
            max_seqlen: Maximum query length forwarded to flash-attn.
            cu_seqlens_k: Cumulative key lengths (fused ``nopadding`` path only).
            max_seqlen_k: Maximum key length (fused ``nopadding`` path only).
            rotary_pos_emb: Optional complex rotary table passed to
                ``apply_rotary_emb``; only used when ``nopadding`` is true.
            incremental_state: Optional dict used as KV cache. ``prev_k`` /
                ``prev_v`` are read if present; ``cur_k`` / ``cur_v`` are written.
            nopadding: Whether ``x`` carries no padding tokens.

        Returns:
            Tensor of shape ``(B, N, C)``.

        Raises:
            NotImplementedError: If ``incremental_state`` is given together with
                ``nopadding=False``.
        """
        B, N, C = x.shape

        if nopadding:
            q, k, v = self._project_qkv(x)
            if rotary_pos_emb is not None:
                q, k = apply_rotary_emb(q, k, rotary_pos_emb)
            if incremental_state is not None:
                k, v = self._extend_kv_cache(incremental_state, k, v)

            if self.fused_attn:
                x = flash_attn_varlen_func(
                    q=q.reshape(B * N, self.num_heads, self.head_dim),
                    k=k.reshape(-1, self.num_heads, self.head_dim),
                    v=v.reshape(-1, self.num_heads, self.head_dim),
                    cu_seqlens_q=cu_seqlens,
                    cu_seqlens_k=cu_seqlens_k,
                    max_seqlen_q=max_seqlen,
                    max_seqlen_k=max_seqlen_k,
                    dropout_p=self.attn_drop.p if self.training else 0.,
                )
            else:
                # NOTE(review): the dense fallback attends within each batch row
                # and does not read cu_seqlens; it matches the fused path only
                # if cu_seqlens describe one sequence per row - confirm with callers.
                x = self._dense_attention(q, k, v)
        else:
            if incremental_state is not None:
                raise NotImplementedError("It is designed for batching inference. AR-chunk is not supported currently.")
            # NOTE(review): rotary_pos_emb is not applied on the padded path
            # (neither fused nor dense) - confirm this is intended.
            q, k, v = self._project_qkv(x)
            if self.fused_attn:
                x = self._padded_flash_attention(q, k, v, seq_len, cu_seqlens, max_seqlen, N)
            else:
                x = self._dense_attention(q, k, v, seq_len=seq_len)

        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def modulate(x, shift, scale):
    """Apply an affine modulation: scale ``x`` by ``1 + scale``, then add ``shift``."""
    gain = 1 + scale
    return x * gain + shift
class FinalLayer(nn.Module):
    """
    Output head of the DiT: a parameter-free LayerNorm whose result is
    shifted and scaled by the conditioning signal, followed by a linear
    projection to ``out_channels``.
    """

    def __init__(self, hidden_size, out_channels):
        super().__init__()
        # The norm carries no affine parameters; shift/scale come from the
        # conditioning branch instead.
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size, bias=True),
        )

    def forward(self, x, c):
        """Modulate the normalised ``x`` with ``c`` and project it.

        The modulation output is split in two along dim 2, so ``c`` must have
        at least three dimensions.
        """
        shift, scale = torch.chunk(self.adaLN_modulation(c), 2, dim=2)
        normed = self.norm_final(x)
        return self.linear(modulate(normed, shift, scale))
class DiTBlock(nn.Module):
    """
    Transformer block whose attention and MLP branches are conditioned via
    adaptive layer norm (adaLN-Zero style): the conditioning signal yields a
    shift, scale and gate for each of the two residual branches.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, ffn_type="conv1d_conv1d", ffn_gated_glu=True, ffn_act_layer="gelu", ffn_conv_kernel_size=5, **block_kwargs):
        """Build the block.

        Only ``ffn_type="vanilla_mlp"`` is implemented; any other value (the
        default included) raises ``NotImplementedError``. ``ffn_gated_glu``,
        ``ffn_act_layer`` and ``ffn_conv_kernel_size`` are accepted but not
        read. Extra keyword arguments are forwarded to ``Attention``.
        """
        super().__init__()
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)

        if ffn_type != "vanilla_mlp":
            raise NotImplementedError(f"FFN type {ffn_type} is not implemented")

        # Imported lazily so timm is only needed when this FFN is requested.
        from timm.models.vision_transformer import Mlp

        def tanh_gelu():
            return nn.GELU(approximate="tanh")

        self.mlp = Mlp(
            in_features=hidden_size,
            hidden_features=int(hidden_size * mlp_ratio),
            act_layer=tanh_gelu,
            drop=0,
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
        )

    def forward(self, x, c, seq_len, cu_seqlens, cu_maxlen, cu_seqlens_k, cu_maxlen_k, mask, rotary_pos_emb=None, incremental_state=None, nopadding=True):
        """Run the gated attention and MLP residual branches.

        ``c`` is mapped to six modulation tensors split along dim 2. When
        ``incremental_state`` is given, its ``"attn_kvcache"`` entry (created
        on first use) is handed to the attention layer. When ``nopadding`` is
        false, both branch outputs are multiplied by ``mask[:, :, None]``
        before being added to the residual stream.
        """
        (shift_msa, scale_msa, gate_msa,
         shift_mlp, scale_mlp, gate_mlp) = self.adaLN_modulation(c).chunk(6, dim=2)

        # --- attention branch ---
        attn_in = modulate(self.norm1(x), shift_msa, scale_msa)
        if incremental_state is None:
            kv_cache = None
        else:
            kv_cache = incremental_state.setdefault("attn_kvcache", {})
        attn_out = self.attn(
            attn_in,
            seq_len=seq_len,
            cu_seqlens=cu_seqlens,
            max_seqlen=cu_maxlen,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_k=cu_maxlen_k,
            rotary_pos_emb=rotary_pos_emb,
            incremental_state=kv_cache,
            nopadding=nopadding,
        )
        if not nopadding:
            attn_out = attn_out * mask[:, :, None]
        x = x + gate_msa * attn_out

        # --- feed-forward branch ---
        mlp_out = self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
        if not nopadding:
            mlp_out = mlp_out * mask[:, :, None]
        return x + gate_mlp * mlp_out