import math

import torch
import torch.nn.functional as F


try:
    import flash_attn
    from flash_attn.flash_attn_interface import (
        _flash_attn_forward,
        flash_attn_func,
        flash_attn_varlen_func,
    )
except ImportError:
    flash_attn = None
    flash_attn_varlen_func = None
    _flash_attn_forward = None
    flash_attn_func = None

MEMORY_LAYOUT = {
    # "flash" mode:
    # pre-process: input is [batch_size, seq_len, num_heads, head_dim]
    # post-process: shape is left unchanged
    "flash": (
        lambda x: x,  # keep shape
        lambda x: x,  # keep shape
    ),
    # "torch"/"vanilla" modes:
    # pre-process: swap the sequence and head dimensions [B,S,A,D] -> [B,A,S,D]
    # post-process: swap back to the original order [B,A,S,D] -> [B,S,A,D]
    "torch": (
        lambda x: x.transpose(1, 2),  # (B,S,A,D) -> (B,A,S,D)
        lambda x: x.transpose(1, 2),  # (B,A,S,D) -> (B,S,A,D)
    ),
    "vanilla": (
        lambda x: x.transpose(1, 2),
        lambda x: x.transpose(1, 2),
    ),
}


def attention(
    q,
    k,
    v,
    mode="torch",
    drop_rate=0,
    attn_mask=None,
    causal=False,
):
    """
    执行QKV自注意力计算

    Args:
        q (torch.Tensor): 查询张量,形状 [batch_size, seq_len, num_heads, head_dim]
        k (torch.Tensor): 键张量,形状 [batch_size, seq_len_kv, num_heads, head_dim]
        v (torch.Tensor): 值张量,形状 [batch_size, seq_len_kv, num_heads, head_dim]
        mode (str): 注意力模式,可选 'flash', 'torch', 'vanilla'
        drop_rate (float): 注意力矩阵的dropout概率
        attn_mask (torch.Tensor): 注意力掩码,形状根据模式不同而变化
        causal (bool): 是否使用因果注意力(仅关注前面位置)

    Returns:
        torch.Tensor: 注意力输出,形状 [batch_size, seq_len, num_heads * head_dim]
    """
    # Look up the pre- and post-processing layout functions for this mode
    pre_attn_layout, post_attn_layout = MEMORY_LAYOUT[mode]

    # Apply the pre-processing transform
    q = pre_attn_layout(q)  # resulting shape depends on the mode
    k = pre_attn_layout(k)
    v = pre_attn_layout(v)

    if mode == "torch":
        # Use PyTorch's native scaled_dot_product_attention
        if attn_mask is not None and attn_mask.dtype != torch.bool:
            attn_mask = attn_mask.to(q.dtype)
        x = F.scaled_dot_product_attention(
            q, k, v, attn_mask=attn_mask, dropout_p=drop_rate, is_causal=causal
        )
    elif mode == "flash":
        assert flash_attn_func is not None, "flash_attn_func is not available"
        assert attn_mask is None, "attn_mask is not supported in flash mode"
        x: torch.Tensor = flash_attn_func(
            q, k, v, dropout_p=drop_rate, causal=causal, softmax_scale=None
        )  # type: ignore
    elif mode == "vanilla":
        # Manual attention implementation
        scale_factor = 1 / math.sqrt(q.size(-1))  # scaling factor 1/sqrt(d_k)

        b, a, s, _ = q.shape  # batch size, num heads, query length, head dim
        s1 = k.size(2)  # key/value sequence length

        # Initialize the additive attention bias
        attn_bias = torch.zeros(b, a, s, s1, dtype=q.dtype, device=q.device)

        # Apply the causal mask
        if causal:
            assert attn_mask is None, "causal masking and an explicit attn_mask cannot be combined"
            # Build a lower-triangular causal mask
            temp_mask = torch.ones(b, a, s, s, dtype=torch.bool, device=q.device).tril(
                diagonal=0
            )
            attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
            attn_bias = attn_bias.to(q.dtype)

        # Apply a user-provided attention mask
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask  # allows additive position biases such as ALiBi

        # Compute the attention scores
        attn = (q @ k.transpose(-2, -1)) * scale_factor  # [B,A,S,S1]
        attn += attn_bias

        # Softmax and dropout (note: dropout is always active here, even at eval time)
        attn = attn.softmax(dim=-1)
        attn = torch.dropout(attn, p=drop_rate, train=True)

        # Compute the output
        x = attn @ v  # [B,A,S,D]
    else:
        raise NotImplementedError(f"Unsupported attention mode: {mode}")

    # Apply the post-processing transform to restore the original dimension order
    x = post_attn_layout(x)  # [B,S,A,D]

    # Merge the head and head-dim dimensions
    b, s, a, d = x.shape
    out = x.reshape(b, s, -1)  # [B,S,A*D]
    return out
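

if __name__ == "__main__":
    # Illustrative usage sketch (not part of the original module). The tensor
    # shapes below are assumptions chosen only to demonstrate the expected
    # [batch_size, seq_len, num_heads, head_dim] input layout and the
    # [batch_size, seq_len, num_heads * head_dim] output layout.
    batch_size, seq_len, num_heads, head_dim = 2, 16, 4, 32
    q = torch.randn(batch_size, seq_len, num_heads, head_dim)
    k = torch.randn(batch_size, seq_len, num_heads, head_dim)
    v = torch.randn(batch_size, seq_len, num_heads, head_dim)

    # The 'torch' backend works on CPU; 'flash' additionally requires the
    # flash-attn package and a CUDA device.
    out = attention(q, k, v, mode="torch", causal=True)
    print(out.shape)  # torch.Size([2, 16, 128])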