import math
import torch
import torch.nn.functional as F
try:
    import flash_attn
    from flash_attn.flash_attn_interface import (
        _flash_attn_forward,
        flash_attn_func,
        flash_attn_varlen_func,
    )
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None
    flash_attn_func = None
MEMORY_LAYOUT = {
    # "flash" mode:
    #   pre-process: input stays [batch_size, seq_len, num_heads, head_dim]
    #   post-process: shape unchanged
    "flash": (
        lambda x: x,  # keep shape
        lambda x: x,  # keep shape
    ),
    # "torch"/"vanilla" modes:
    #   pre-process: swap the sequence and head dimensions [B, S, A, D] -> [B, A, S, D]
    #   post-process: swap back to the original layout [B, A, S, D] -> [B, S, A, D]
    "torch": (
        lambda x: x.transpose(1, 2),  # (B, S, A, D) -> (B, A, S, D)
        lambda x: x.transpose(1, 2),  # (B, A, S, D) -> (B, S, A, D)
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}
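
# Illustrative note (added as a sketch, not part of the original module): the
# "torch"/"vanilla" layouts move the head axis next to batch so that the last
# two axes are (seq_len, head_dim), which is what both
# F.scaled_dot_product_attention and the manual matmul path below expect,
# while "flash" keeps the [B, S, A, D] layout that flash_attn_func consumes:
#
#     x = torch.randn(2, 16, 8, 64)   # hypothetical [B, S, A, D] input
#     pre, post = MEMORY_LAYOUT["torch"]
#     pre(x).shape                    # (2, 8, 16, 64), i.e. [B, A, S, D]
#     post(pre(x)).shape              # (2, 16, 8, 64), round-trips back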
def attention(
    q,
    k,
    v,
    mode="torch",
    drop_rate=0,
    attn_mask=None,
    causal=False,
):
    """
    Compute QKV self-attention.

    Args:
        q (torch.Tensor): query tensor of shape [batch_size, seq_len, num_heads, head_dim]
        k (torch.Tensor): key tensor of shape [batch_size, seq_len_kv, num_heads, head_dim]
        v (torch.Tensor): value tensor of shape [batch_size, seq_len_kv, num_heads, head_dim]
        mode (str): attention backend, one of 'flash', 'torch', 'vanilla'
        drop_rate (float): dropout probability applied to the attention matrix
        attn_mask (torch.Tensor): attention mask; its expected shape depends on the mode
        causal (bool): whether to use causal attention (each position attends only to earlier positions)

    Returns:
        torch.Tensor: attention output of shape [batch_size, seq_len, num_heads * head_dim]
    """
    # Fetch the pre- and post-processing functions for the chosen backend
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    # Apply the pre-processing transform
    q = pre_attn_layout(q)  # resulting shape depends on the mode
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        # Use PyTorch's native scaled_dot_product_attention
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
        )
    elif mode == "flash":
        assert flash_attn_func is not None, "flash_attn_func is not available"
        assert attn_mask is None, "attn_mask is not supported in flash mode"
        x: torch.Tensor = flash_attn_func(
            q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
        )  # type: ignore
    elif mode == "vanilla":
        # Manual attention implementation
        scale_factor = 1 / math.sqrt(q.size(-1))  # scaling factor 1/sqrt(d_k)
        b, a, s, _ = q.shape  # unpack shape parameters
        s1 = k.size(2)  # key/value sequence length

        # Initialize the attention bias
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)

        # Handle the causal mask
        if causal:
            assert attn_mask is None, "causal masking and an explicit attn_mask cannot be combined"
            # Build a lower-triangular causal mask (assumes s == s1)
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
                diagonal=0
            )
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        # Handle a user-supplied attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask  # allows ALiBi-style positional biases

        # Compute the attention matrix
        attn = (q @ k.transpose(-2, -1)) * scale_factor  # [B, A, S, S1]
        attn += attn_bias

        # Softmax and dropout
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)

        # Compute the output
        x = attn @ v  # [B, A, S, D]
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")
    # Apply the post-processing transform to restore the original dimension order
    x = post_attn_layout(x)

    # Merge the attention heads into a single feature dimension
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)  # [B, S, A*D]
    return out
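

# A minimal usage sketch (added for illustration; the shapes and tolerance are
# assumptions, not part of the original file). With dropout disabled, the
# "torch" and "vanilla" backends should agree up to floating-point error.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, S, A, D = 2, 16, 8, 64  # hypothetical batch, sequence, heads, head_dim
    q = torch.randn(B, S, A, D)
    k = torch.randn(B, S, A, D)
    v = torch.randn(B, S, A, D)

    out_torch = attention(q, k, v, mode="torch", causal=True)
    out_vanilla = attention(q, k, v, mode="vanilla", causal=True)

    print(out_torch.shape)                                    # torch.Size([2, 16, 512])
    print(torch.allclose(out_torch, out_vanilla, atol=1e-5))  # expected: True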