From d59ac670bf92f18ba9db44f37fef93b002528ac8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Johannes=20G=C3=A4=C3=9Fler?= Date: Sat, 30 Mar 2024 09:19:19 +0100 Subject: [PATCH] 16 cols for Phi-2 --- ggml-cuda/fattn.cu | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu index ccb3c9246..d34924c31 100644 --- a/ggml-cuda/fattn.cu +++ b/ggml-cuda/fattn.cu @@ -579,15 +579,17 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst return; } - int cols_per_block; - if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { - cols_per_block = 64; - } else if (Q->ne[1] >= 64) { - cols_per_block = 32; - } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { - cols_per_block = 16; - } else { - cols_per_block = 8; + int cols_per_block = 16; + if (Q->ne[0] % 32 == 0) { + if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { + cols_per_block = 64; + } else if (Q->ne[1] >= 64) { + cols_per_block = 32; + } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { + cols_per_block = 16; + } else { + cols_per_block = 8; + } } const int frag_m = cols_per_block == 8 ? 32 : 16; const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;