import pytest
import torch

import flash_attn

# TODO: improve and add more tests
@pytest.mark.skipif(not torch.cuda.is_available(), reason="flash_attn requires a CUDA device")
def test_flash_attn():
    # FlashAttention expects CUDA tensors in half precision (float16 or
    # bfloat16) with (batch_size, seqlen, nheads, head_dim) layout.
    q = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    out = torch.empty_like(q)
    # One ALiBi slope per head, float32 on the same device; zeros keep the
    # test deterministic (torch.empty left the slopes uninitialized).
    alibi_slopes = torch.zeros(4, device="cuda", dtype=torch.float32)
    p_dropout = 0.1
    softmax_scale = 1.0
    is_causal = False
    window_size_left = 0
    window_size_right = 0
    softcap = 0.0
    return_softmax = False
    gen = None

    out = flash_attn.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )

    assert out.shape == (2, 5, 4, 8)
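

# A possible follow-up for the TODO above: a numerical check against
# PyTorch's built-in scaled_dot_product_attention. This is a sketch that
# assumes mha_fwd returns the output tensor, as the shape test above does.
# mha_fwd uses (batch, seqlen, nheads, head_dim) layout while the PyTorch
# reference expects (batch, nheads, seqlen, head_dim), hence the transposes;
# dropout and ALiBi are disabled so the comparison is deterministic.
@pytest.mark.skipif(not torch.cuda.is_available(), reason="flash_attn requires a CUDA device")
def test_flash_attn_matches_sdpa():
    q = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    k = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    v = torch.randn(2, 5, 4, 8, device="cuda", dtype=torch.float16)
    out = torch.empty_like(q)
    softmax_scale = 8 ** -0.5  # 1 / sqrt(head_dim)

    out = flash_attn.mha_fwd(
        q,
        k,
        v,
        out,
        None,  # alibi_slopes: disabled for the reference comparison
        0.0,  # p_dropout: disabled so outputs are deterministic
        softmax_scale,
        False,  # is_causal
        -1,  # window_size_left: -1 disables local attention
        -1,  # window_size_right
        0.0,  # softcap
        False,  # return_softmax
        None,  # gen
    )

    ref = torch.nn.functional.scaled_dot_product_attention(
        q.transpose(1, 2),
        k.transpose(1, 2),
        v.transpose(1, 2),
        scale=softmax_scale,
    ).transpose(1, 2)

    torch.testing.assert_close(out, ref, atol=2e-3, rtol=2e-3)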