From 269374ed818dde2b267307d62be7ba59385aebfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sun, 31 Mar 2024 16:01:27 +0200 Subject: [PATCH] adjust kernel selection logic --- ggml-cuda/fattn.cu | 20 +++++++++----------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index 43b9a9f4a..f2c460086 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -579,17 +579,15 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - int cols_per_block = 16; - if (Q->ne[0] % 32 == 0) { - if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; - } + int cols_per_block; + if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; } const int frag_m = cols_per_block == 8 ? 32 : 16; const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;