adjust kernel selection logic
This commit is contained in:
parent
81da919864
commit
269374ed81
1 changed files with 9 additions and 11 deletions
|
@ -579,9 +579,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
int cols_per_block = 16;
|
int cols_per_block;
|
||||||
if (Q->ne[0] % 32 == 0) {
|
if (Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
|
||||||
if (Q->ne[1] >= 128 && Q->ne[0] <= 128) {
|
|
||||||
cols_per_block = 64;
|
cols_per_block = 64;
|
||||||
} else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
|
} else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
|
||||||
cols_per_block = 32;
|
cols_per_block = 32;
|
||||||
|
@ -590,7 +589,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
} else {
|
} else {
|
||||||
cols_per_block = 8;
|
cols_per_block = 8;
|
||||||
}
|
}
|
||||||
}
|
|
||||||
const int frag_m = cols_per_block == 8 ? 32 : 16;
|
const int frag_m = cols_per_block == 8 ? 32 : 16;
|
||||||
const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;
|
const int nwarps = (Q->ne[0] <= 128 || cols_per_block == 8 ? Q->ne[0] : Q->ne[0]/2) / frag_m;
|
||||||
const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
|
const dim3 blocks_num((Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue