From 81da919864831948f292aeb0a5bd11eb5868bdb8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?=
Date: Sat, 30 Mar 2024 10:34:09 +0100
Subject: [PATCH] no vec for hs, no hs==256 ncols==32 for Volta

---
 ggml-cuda/common.cuh |  1 +
 ggml-cuda/fattn.cu   | 72 ++++++++++++++++++++++----------------
 2 files changed, 37 insertions(+), 36 deletions(-)

diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index 33c8ed1da..c245dd6ac 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -141,6 +141,7 @@
 #define CC_PASCAL     600
 #define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
 #define CC_VOLTA      700
+#define CC_AMPERE     800
 #define CC_OFFSET_AMD 1000000
 #define CC_RDNA1      (CC_OFFSET_AMD + 1010)
 #define CC_RDNA2      (CC_OFFSET_AMD + 1030)
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index d34924c31..43b9a9f4a 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -463,29 +463,29 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));
 
-    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[1] == 1) {
+    if (Q->ne[0] % WARP_SIZE == 0 && Q->ne[0] >= 128 && Q->ne[1] == 1) {
         const int nwarps = Q->ne[0] / WARP_SIZE;
         const dim3 blocks_num(Q->ne[1], Q->ne[2], Q->ne[3]);
         const dim3 block_dim(WARP_SIZE, nwarps, 1);
         const int shmem = 0;
         switch (Q->ne[0]) {
-            case 64:
-                flash_attn_vec_ext_f16<64>
-                    <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                    );
-                break;
+            // case 64:
+            //     flash_attn_vec_ext_f16<64>
+            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
+            //             (const char *) Q->data, // Query
+            //             (const char *) K->data, // Key
+            //             (const char *) V->data, // Value
+            //             mask ? ((const char *) mask->data) : nullptr, // Mask
+            //             (float *) KQV->data, // dst
+            //             scale,
+            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+            //             Q->nb[1], Q->nb[2], Q->nb[3],
+            //             K->nb[1], K->nb[2], K->nb[3],
+            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+            //         );
+            //     break;
             // case 80:
             //     flash_attn_vec_ext_f16<80>
             //         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -503,23 +503,23 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
             //         );
             //     break;
-            case 96:
-                flash_attn_vec_ext_f16<96>
-                    <<<blocks_num, block_dim, shmem, main_stream>>> (
-                        (const char *) Q->data, // Query
-                        (const char *) K->data, // Key
-                        (const char *) V->data, // Value
-                        mask ? ((const char *) mask->data) : nullptr, // Mask
-                        (float *) KQV->data, // dst
-                        scale,
-                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
-                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
-                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
-                        Q->nb[1], Q->nb[2], Q->nb[3],
-                        K->nb[1], K->nb[2], K->nb[3],
-                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
-                    );
-                break;
+            // case 96:
+            //     flash_attn_vec_ext_f16<96>
+            //         <<<blocks_num, block_dim, shmem, main_stream>>> (
+            //             (const char *) Q->data, // Query
+            //             (const char *) K->data, // Key
+            //             (const char *) V->data, // Value
+            //             mask ? ((const char *) mask->data) : nullptr, // Mask
+            //             (float *) KQV->data, // dst
+            //             scale,
+            //             Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+            //             K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+            //             mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+            //             Q->nb[1], Q->nb[2], Q->nb[3],
+            //             K->nb[1], K->nb[2], K->nb[3],
+            //             KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+            //         );
+            //     break;
             // case 112:
             //     flash_attn_vec_ext_f16<112>
             //         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -583,7 +583,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     if (Q->ne[0] % 32 == 0) {
         if (Q->ne[1] >= 128 && Q->ne[0] <= 128) {
             cols_per_block = 64;
-        } else if (Q->ne[1] >= 64) {
+        } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
             cols_per_block = 32;
         } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
             cols_per_block = 16;
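
Note on the first fattn.cu hunk: adding Q->ne[0] >= 128 to the gate means the single-column (batch size 1) vec kernel is now only considered for head sizes of at least 128; head sizes 64 and 96 fall through to the tile kernels, which is why their switch cases are commented out above. A minimal standalone sketch of the new gate, assuming a hypothetical helper name (takes_vec_path is not part of the patch):

#include <cstdio>

#define WARP_SIZE 32

// Mirrors the patched condition in ggml_cuda_flash_attn_ext: head size
// divisible by the warp size, at least 128, and a single query column.
static bool takes_vec_path(int hs /* Q->ne[0] */, int ncols /* Q->ne[1] */) {
    return hs % WARP_SIZE == 0 && hs >= 128 && ncols == 1;
}

int main() {
    printf("hs= 64 ncols=1 -> %d\n", takes_vec_path( 64, 1)); // 0: case commented out
    printf("hs= 96 ncols=1 -> %d\n", takes_vec_path( 96, 1)); // 0: case commented out
    printf("hs=128 ncols=1 -> %d\n", takes_vec_path(128, 1)); // 1
    printf("hs=256 ncols=1 -> %d\n", takes_vec_path(256, 1)); // 1
    return 0;
}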
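
Note on the last hunk: the new CC_AMPERE constant gates the 32-column tile kernel so that head sizes above 128 (hs == 256, per the subject line) only use it on Ampere or newer; on Volta/Turing the dispatch drops to 16 columns instead. A simplified sketch of the heuristic, assuming the surrounding dispatch code looks like the hunk context (pick_cols_per_block is a made-up name and the real code handles more cases):

#include <cstdio>

#define CC_VOLTA  700
#define CC_AMPERE 800

// Simplified version of the cols_per_block selection after this patch.
static int pick_cols_per_block(int hs /* Q->ne[0] */, int ncols /* Q->ne[1] */, int cc) {
    int cols_per_block = 8; // assumed small-batch fallback, illustrative only
    if (hs % 32 == 0) {
        if (ncols >= 128 && hs <= 128) {
            cols_per_block = 64;
        } else if (ncols >= 64 && (hs <= 128 || cc >= CC_AMPERE)) {
            // hs > 128 only gets the 32-column tile on Ampere or newer.
            cols_per_block = 32;
        } else if (ncols >= 32) {
            cols_per_block = 16;
        }
    }
    return cols_per_block;
}

int main() {
    printf("hs=256 ncols=64 on Volta  -> %d\n", pick_cols_per_block(256, 64, CC_VOLTA));  // 16
    printf("hs=256 ncols=64 on Ampere -> %d\n", pick_cols_per_block(256, 64, CC_AMPERE)); // 32
    return 0;
}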