from einops import rearrange
import torch
from torch import Tensor
import torch.nn.functional as F

from flash_attn import flash_attn_varlen_func
from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa


def _upad_input(query_layer, key_layer, value_layer, query_mask, key_mask, query_length):
    """Strip padding from q/k/v and pack them into the flat (total_tokens, heads, head_dim)
    layout expected by flash_attn_varlen_func, returning the cumulative and maximum
    sequence lengths for queries and keys."""

    def _get_unpad_data(attention_mask):
        # Per-sequence token counts, flat indices of real tokens, and cumulative offsets.
        seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
        indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
        max_seqlen_in_batch = seqlens_in_batch.max().item()
        cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
        return (
            indices,
            cu_seqlens,
            max_seqlen_in_batch,
        )

    indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(key_mask)
    _, q_seq_len, num_query_heads, _ = query_layer.shape
    batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape

    key_layer = index_first_axis(
        key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
        indices_k,
    )
    value_layer = index_first_axis(
        value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
        indices_k,
    )
    if query_length == kv_seq_len and key_mask is None:
        query_layer = index_first_axis(
            query_layer.reshape(batch_size * kv_seq_len, num_query_heads, head_dim),
            indices_k,
        )
        cu_seqlens_q = cu_seqlens_k
        max_seqlen_in_batch_q = max_seqlen_in_batch_k
        indices_q = indices_k
    elif query_length == 1:
        max_seqlen_in_batch_q = 1
        cu_seqlens_q = torch.arange(
            batch_size + 1, dtype=torch.int32, device=query_layer.device
        )  # There is a memcpy here, that is very bad.
        indices_q = cu_seqlens_q[:-1]
        query_layer = query_layer.squeeze(1)
    else:
        # The -q_len: slice assumes left padding.
        query_mask = query_mask[:, -query_length:]
        # unpad_input is expected to return five values here; the trailing one is unused.
        query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q, _ = unpad_input(query_layer, query_mask)

    return (
        query_layer,
        key_layer,
        value_layer,
        indices_q,
        (cu_seqlens_q, cu_seqlens_k),
        (max_seqlen_in_batch_q, max_seqlen_in_batch_k),
    )
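

# Worked example (illustrative only, not executed): with key_mask
#     [[1, 1, 1, 1],
#      [1, 1, 1, 0]]
# the two sequences contain 4 and 3 real tokens, so the packed key/value tensors keep
# 7 rows, cu_seqlens_k == tensor([0, 4, 7], dtype=torch.int32) and
# max_seqlen_in_batch_k == 4; flash_attn_varlen_func uses these offsets to locate each
# sequence inside the packed batch.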


def attention(
    q: Tensor, k: Tensor, v: Tensor, pe: Tensor,
    attn_mask: Tensor | None = None, drop_mask: Tensor | None = None,
) -> Tensor:
    """Rotary-embedded, variable-length flash attention over padded sequences.

    q, k, v are (batch, heads, seq_len, head_dim); pe holds the rotary embeddings
    produced by rope(); attn_mask is a (batch, seq_len) padding mask with 1 for real
    tokens and 0 for padding. Returns (batch, seq_len, heads * head_dim).
    """
    q, k = apply_rope(q, k, pe)

    # flash_attn_varlen_func expects (batch, seq_len, heads, head_dim).
    q = q.transpose(1, 2)
    k = k.transpose(1, 2)
    v = v.transpose(1, 2)
    B, L, H, D = q.shape

    if drop_mask is None:  # todo: remove drop mask
        drop_mask = attn_mask

    (
        query_states,
        key_states,
        value_states,
        indices_q,
        cu_seq_lens,
        max_seq_lens,
    ) = _upad_input(q, k, v, attn_mask, drop_mask, L)
    cu_seqlens_q, cu_seqlens_k = cu_seq_lens
    max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens

    attn_output_unpad = flash_attn_varlen_func(
        query_states,
        key_states,
        value_states,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_in_batch_q,
        max_seqlen_k=max_seqlen_in_batch_k,
        dropout_p=0.0,
        causal=False,
    )

    # Re-pad to (B, L, H, D), then merge the head dimension.
    x = pad_input(attn_output_unpad, indices_q, B, L)
    x = rearrange(x, "B L H D -> B L (H D)")
    return x
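

# Hypothetical call sketch (illustrative only): the flash-attn kernels require a CUDA
# device and fp16/bf16 inputs, and pe is assumed here to carry a singleton head
# dimension so it broadcasts across heads inside apply_rope().
#
#     B, H, L, D = 1, 24, 512, 128
#     q = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
#     k = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
#     v = torch.randn(B, H, L, D, device="cuda", dtype=torch.bfloat16)
#     pos = torch.arange(L, dtype=torch.float64, device="cuda").expand(B, L)
#     pe = rope(pos, D, 10_000).unsqueeze(1)                 # (B, 1, L, D // 2, 2, 2)
#     attn_mask = torch.ones(B, L, dtype=torch.int32, device="cuda")  # 1 = real token
#     out = attention(q, k, v, pe, attn_mask=attn_mask)      # (B, L, H * D)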


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos, omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.float()
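

# Shape note: for positions of shape (batch, n_tokens) and an even head dimension `dim`,
# rope() returns a (batch, n_tokens, dim // 2, 2, 2) tensor, i.e. one 2x2 rotation matrix
# [[cos, -sin], [sin, cos]] per token and frequency band, with frequencies
# 1 / theta ** (2 * i / dim) as in standard rotary position embeddings.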


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor) -> tuple[Tensor, Tensor]:
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
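

if __name__ == "__main__":
    # Illustrative sanity check, not part of the original module: every rope() entry is a
    # 2x2 rotation matrix, so apply_rope() should preserve the per-token norm of q and k.
    # The unsqueeze(1) below is an assumption: it adds a singleton head dimension so the
    # embeddings broadcast across heads, matching how FLUX-style callers build pe.
    B, H, L, D = 2, 4, 16, 64
    q = torch.randn(B, H, L, D)
    k = torch.randn(B, H, L, D)
    pos = torch.arange(L, dtype=torch.float64).expand(B, L)  # (B, L) token positions
    pe = rope(pos, D, 10_000).unsqueeze(1)                    # (B, 1, L, D // 2, 2, 2)
    q_rot, k_rot = apply_rope(q, k, pe)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    assert torch.allclose(q.norm(dim=-1), q_rot.norm(dim=-1), atol=1e-4)
    print("rope/apply_rope norm-preservation check passed")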