#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-mma-f16.cuh"
#include "fattn-tile-f16.cuh"
#include "fattn-tile-f32.cuh"
#include "fattn-vec-f16.cuh"
#include "fattn-vec-f32.cuh"
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"

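// Dispatch helper for the mma (tensor core) FlashAttention kernels: given the head size D and the
// number of grouped query heads ncols2, pick the smallest supported number of Q columns per kernel
// (ncols1) that still covers the batch size Q->ne[1], falling back to the largest variant (64/ncols2).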
template <int D, int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

    if (Q->ne[1] <= 8/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<D, 8/ncols2, ncols2>(ctx, dst);
        return;
    }

    if (Q->ne[1] <= 16/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<D, 16/ncols2, ncols2>(ctx, dst);
        return;
    }

    if (Q->ne[1] <= 32/ncols2) {
        ggml_cuda_flash_attn_ext_mma_f16_case<D, 32/ncols2, ncols2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_case<D, 64/ncols2, ncols2>(ctx, dst);
}

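// Dispatch helper for the mma kernels: select the template instantiation matching the head size
// Q->ne[0]; unsupported head sizes abort.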
template <int ncols2>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * Q = dst->src[0];

    switch (Q->ne[0]) {
        case 64:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 64, ncols2>(ctx, dst);
            break;
        case 80:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 80, ncols2>(ctx, dst);
            break;
        case 96:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1< 96, ncols2>(ctx, dst);
            break;
        case 112:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<112, ncols2>(ctx, dst);
            break;
        case 128:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<128, ncols2>(ctx, dst);
            break;
        case 256:
            ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<256, ncols2>(ctx, dst);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}

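// Top-level mma dispatch: choose how many query heads sharing the same K/V head are processed
// together (ncols2), based on the GQA ratio Q->ne[2]/K->ne[2]. This grouping is only applied when a
// mask is present and ALiBi is disabled (max_bias == 0); otherwise ncols2 == 1 is used.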
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * mask = dst->src[3];

    float max_bias = 0.0f;
    memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));

    const bool use_gqa_opt = mask && max_bias == 0.0f;

    GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
    const int gqa_ratio = Q->ne[2] / K->ne[2];

    if (use_gqa_opt && gqa_ratio % 8 == 0) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio == 4) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<4>(ctx, dst);
        return;
    }

    if (use_gqa_opt && gqa_ratio == 2) {
        ggml_cuda_flash_attn_ext_mma_f16_switch_hs<2>(ctx, dst);
        return;
    }

    ggml_cuda_flash_attn_ext_mma_f16_switch_hs<1>(ctx, dst);
}

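// Dispatch macro for the f16-accumulator vec kernel: returns early once the head size and the K/V
// tensor types match. With GGML_CUDA_FA_ALL_QUANTS defined, all quantized K/V combinations below
// are compiled in; otherwise only a reduced set is available.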
#define FATTN_VEC_F16_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }                                                                       \

static void ggml_cuda_flash_attn_ext_vec_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)

    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#else
    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

    FATTN_VEC_F16_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F16_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F16_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}

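// Same dispatch scheme for the f32-accumulator vec kernel, which the dispatcher below selects when
// higher precision is requested or fast fp16 arithmetic is unavailable.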
#define FATTN_VEC_F32_CASE(D, type_K, type_V)                               \
    if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) {    \
        ggml_cuda_flash_attn_ext_vec_f32_case<D, type_K, type_V>(ctx, dst); \
        return;                                                             \
    }                                                                       \

static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_tensor * Q = dst->src[0];
    ggml_tensor * K = dst->src[1];
    ggml_tensor * V = dst->src[2];

#ifdef GGML_CUDA_FA_ALL_QUANTS
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q4_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q4_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q5_1)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q5_1, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16,  GGML_TYPE_F16)

    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#else
    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q4_0, GGML_TYPE_Q4_0)

    FATTN_VEC_F32_CASE(128, GGML_TYPE_Q8_0, GGML_TYPE_Q8_0)

    FATTN_VEC_F32_CASE( 64, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(128, GGML_TYPE_F16, GGML_TYPE_F16)
    FATTN_VEC_F32_CASE(256, GGML_TYPE_F16, GGML_TYPE_F16)
#endif // GGML_CUDA_FA_ALL_QUANTS

    on_no_fattn_vec_case(Q->ne[0]);
}

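// Entry point: picks between the wmma, mma, tile, and vec FlashAttention kernels (f16 or f32
// accumulation) based on device capabilities, requested precision, batch size, head size, and
// the K/V tensor types.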
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * KQV  = dst;
    const ggml_tensor * Q    = dst->src[0];
    const ggml_tensor * K    = dst->src[1];
    const ggml_tensor * V    = dst->src[2];
    const ggml_tensor * mask = dst->src[3];

    ggml_cuda_set_device(ctx.device);
    const int cc        = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
    const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);

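    // AMD GPUs: use the rocWMMA FlashAttention kernel when it was compiled in, otherwise fall back
    // to the vec kernels in the requested precision.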
    if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
        if (fp16_mma_available(cc)) {
            ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
            return;
        }
#endif // defined(GGML_HIP_ROCWMMA_FATTN)

        if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        }
        return;
    }

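    // Devices without fast fp16 arithmetic can only use the f32 kernels: the vec kernel for small
    // batches and head size 256, the tile kernel otherwise.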
    if (!fast_fp16_available(cc)) {
        if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
        } else {
            ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
        }
        return;
    }

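    // Devices without fp16 tensor core support use the vec kernels for small batches and the tile
    // kernels otherwise, with the accumulator precision that was requested.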
    if (!fp16_mma_available(cc)) {
        if (prec == GGML_PREC_DEFAULT) {
            if (Q->ne[1] <= 8) {
                ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
            } else {
                ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
            }
        } else {
            if (Q->ne[1] <= 8) {
                ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
            } else {
                ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
            }
        }
        return;
    }

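    // Batch size 1: prefer the vec kernels unless mma_fast_for_bs1 indicates that the mma kernel
    // can exploit GQA (even GQA ratio, f16 K and V, mask present).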
    const int gqa_ratio = Q->ne[2] / K->ne[2];
    const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
        K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
    if (Q->ne[1] == 1 && Q->ne[0] % (2*warp_size) == 0 && !mma_fast_for_bs1) {
        if (prec == GGML_PREC_DEFAULT) {
            ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
            return;
        } else if (Q->ne[0] <= 128) {
            ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
            return;
        }
    }

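    // The mma kernels need the newer mma instruction layout; devices that only have the older fp16
    // tensor core path use the wmma kernel instead.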
    if (fp16_mma_available(cc) && !new_mma_available(cc)) {
        ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
        return;
    }

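    // Default path: the mma-based FlashAttention kernel.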
    ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
}