adjust kernel selection logic

This commit is contained in:
Johannes Gäßler 2024-03-31 16:01:27 +02:00 committed by Georgi Gerganov
parent 81da919864
commit 269374ed81

View file

@@ -579,17 +579,15 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return; return;
} }
int cols_per_block = 16; int cols_per_block;
if (Q->ne[0] % 32 == 0) { if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
if (Q->ne[1] >= 128 && Q->ne[0] <= 128) { cols_per_block = 64;
cols_per_block = 64; } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
} else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) { cols_per_block = 32;
cols_per_block = 32; } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
} else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) { cols_per_block = 16;
cols_per_block = 16; } else {
} else { cols_per_block = 8;
cols_per_block = 8;
}
} }
const int frag_m = cols_per_block == 8 ? 32 : 16; const int frag_m = cols_per_block == 8 ? 32 : 16;
const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m; const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;