no ncols == 64
This commit is contained in:
parent
cca6d027a3
commit
68d793bee8
1 changed files with 1 additions and 9 deletions
|
@ -625,9 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
}
|
}
|
||||||
|
|
||||||
int cols_per_block;
|
int cols_per_block;
|
||||||
if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
|
if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
|
||||||
cols_per_block = 64;
|
|
||||||
} else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
|
|
||||||
cols_per_block = 32;
|
cols_per_block = 32;
|
||||||
} else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
|
} else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
|
||||||
cols_per_block = 16;
|
cols_per_block = 16;
|
||||||
|
@ -645,7 +643,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
FATTN_SWITCH_CASE(64, 8, nwarps);
|
FATTN_SWITCH_CASE(64, 8, nwarps);
|
||||||
FATTN_SWITCH_CASE(64, 16, nwarps);
|
FATTN_SWITCH_CASE(64, 16, nwarps);
|
||||||
FATTN_SWITCH_CASE(64, 32, nwarps);
|
FATTN_SWITCH_CASE(64, 32, nwarps);
|
||||||
FATTN_SWITCH_CASE(64, 64, nwarps);
|
|
||||||
default:
|
default:
|
||||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -655,7 +652,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
// FATTN_SWITCH_CASE(80, 8, nwarps);
|
// FATTN_SWITCH_CASE(80, 8, nwarps);
|
||||||
FATTN_SWITCH_CASE(80, 16, nwarps);
|
FATTN_SWITCH_CASE(80, 16, nwarps);
|
||||||
FATTN_SWITCH_CASE(80, 32, nwarps);
|
FATTN_SWITCH_CASE(80, 32, nwarps);
|
||||||
// FATTN_SWITCH_CASE(80, 64, nwarps);
|
|
||||||
default:
|
default:
|
||||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -665,7 +661,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
FATTN_SWITCH_CASE(96, 8, nwarps);
|
FATTN_SWITCH_CASE(96, 8, nwarps);
|
||||||
FATTN_SWITCH_CASE(96, 16, nwarps);
|
FATTN_SWITCH_CASE(96, 16, nwarps);
|
||||||
FATTN_SWITCH_CASE(96, 32, nwarps);
|
FATTN_SWITCH_CASE(96, 32, nwarps);
|
||||||
FATTN_SWITCH_CASE(96, 64, nwarps);
|
|
||||||
default:
|
default:
|
||||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -675,7 +670,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
// FATTN_SWITCH_CASE(112, 8, nwarps);
|
// FATTN_SWITCH_CASE(112, 8, nwarps);
|
||||||
FATTN_SWITCH_CASE(112, 16, nwarps);
|
FATTN_SWITCH_CASE(112, 16, nwarps);
|
||||||
FATTN_SWITCH_CASE(112, 32, nwarps);
|
FATTN_SWITCH_CASE(112, 32, nwarps);
|
||||||
// FATTN_SWITCH_CASE(112, 64, nwarps);
|
|
||||||
default:
|
default:
|
||||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -685,7 +679,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
FATTN_SWITCH_CASE(128, 8, nwarps);
|
FATTN_SWITCH_CASE(128, 8, nwarps);
|
||||||
FATTN_SWITCH_CASE(128, 16, nwarps);
|
FATTN_SWITCH_CASE(128, 16, nwarps);
|
||||||
FATTN_SWITCH_CASE(128, 32, nwarps);
|
FATTN_SWITCH_CASE(128, 32, nwarps);
|
||||||
// FATTN_SWITCH_CASE(128, 64, nwarps);
|
|
||||||
default:
|
default:
|
||||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -695,7 +688,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
|
||||||
FATTN_SWITCH_CASE(256, 8, nwarps);
|
FATTN_SWITCH_CASE(256, 8, nwarps);
|
||||||
FATTN_SWITCH_CASE(256, 16, nwarps);
|
FATTN_SWITCH_CASE(256, 16, nwarps);
|
||||||
FATTN_SWITCH_CASE(256, 32, nwarps);
|
FATTN_SWITCH_CASE(256, 32, nwarps);
|
||||||
// FATTN_SWITCH_CASE(256, 64, nwarps);
|
|
||||||
default:
|
default:
|
||||||
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue