import torch
import flash_attn


# TODO: improve and add more tests
def test_flash_attn():
    # Query, key, and value tensors of shape (batch, seqlen, num_heads, head_dim).
    q = torch.randn(2, 5, 4, 8)
    k = torch.randn(2, 5, 4, 8)
    v = torch.randn(2, 5, 4, 8)
    # Pre-allocated output buffer, plus one ALiBi slope per attention head.
    out = torch.empty(2, 5, 4, 8)
    alibi_slopes = torch.empty(4)
    # Forward-pass options: dropout probability, softmax scaling, causal masking,
    # local-attention window sizes, softcap, softmax return flag, and RNG generator.
    p_dropout = 0.1
    softmax_scale = 1.0
    is_causal = False
    window_size_left = 0
    window_size_right = 0
    softcap = 0.0
    return_softmax = False
    gen = None
    out = flash_attn.mha_fwd(
        q,
        k,
        v,
        out,
        alibi_slopes,
        p_dropout,
        softmax_scale,
        is_causal,
        window_size_left,
        window_size_right,
        softcap,
        return_softmax,
        gen,
    )
    # The output should keep the input shape (batch, seqlen, num_heads, head_dim).
    assert out.shape == (2, 5, 4, 8)
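

# Allow running this file directly (e.g. `python <this file>`) in addition to
# having the test collected by pytest.
if __name__ == "__main__":
    test_flash_attn()
    print("test_flash_attn passed")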