diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index 43b9a9f4a..f2c460086 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -579,17 +579,15 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         return;
     }

-    int cols_per_block = 16;
-    if (Q->ne[0] % 32 == 0) {
-        if (Q->ne[1] >= 128 && Q->ne[0] <= 128) {
-            cols_per_block = 64;
-        } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
-            cols_per_block = 32;
-        } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
-            cols_per_block = 16;
-        } else {
-            cols_per_block = 8;
-        }
+    int cols_per_block;
+    if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
+        cols_per_block = 64;
+    } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
+        cols_per_block = 32;
+    } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
+        cols_per_block = 16;
+    } else {
+        cols_per_block = 8;
     }
     const int frag_m = cols_per_block == 8 ? 32 : 16;
     const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;
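For reference, a minimal standalone sketch of the new `cols_per_block` selection logic as written in the diff; the names `head_size` (for `Q->ne[0]`), `n_queries` (for `Q->ne[1]`), and `is_ampere_or_newer` (for the `CC_AMPERE` compute-capability check) are illustrative stand-ins, not ggml identifiers:

```cpp
// Sketch of the refactored heuristic. Parameter names are hypothetical;
// in fattn.cu the values come from Q->ne[0], Q->ne[1], and
// ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE.
static int choose_cols_per_block(int head_size, int n_queries, bool is_ampere_or_newer) {
    if (n_queries >= 128 && head_size <= 128 && head_size % 32 == 0) {
        return 64; // large batches, small 32-aligned head sizes
    }
    if (n_queries >= 64 && (head_size <= 128 || is_ampere_or_newer)) {
        return 32;
    }
    if (n_queries >= 32 || head_size % 32 != 0) {
        return 16;
    }
    return 8; // small batches, large 32-aligned head sizes
}
```

Note the behavioral difference from the removed code: previously any `Q->ne[0]` not divisible by 32 skipped the whole inner chain and kept the default of 16, whereas the flattened chain now also lets such head sizes reach the 32-column branch when `Q->ne[1] >= 64` and the `Q->ne[0] <= 128`/Ampere condition holds.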