From 68d793bee816e44876b22232891ee7bab51ee5e7 Mon Sep 17 00:00:00 2001
From: Johannes Gäßler
Date: Mon, 1 Apr 2024 15:54:50 +0200
Subject: [PATCH] no ncols == 64

---
 ggml-cuda/fattn.cu | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index aa85244fc..19108044e 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -625,9 +625,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     }
 
     int cols_per_block;
-    if (false && Q->ne[1] >= 128 && Q->ne[0] <= 128 && Q->ne[0] % 32 == 0) {
-        cols_per_block = 64;
-    } else if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
+    if (Q->ne[1] >= 64 && (Q->ne[0] <= 128 || ggml_cuda_info().devices[ctx.device].cc >= CC_AMPERE)) {
         cols_per_block = 32;
     } else if (Q->ne[1] >= 32 || Q->ne[0] % 32 != 0) {
         cols_per_block = 16;
@@ -645,7 +643,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(64, 8, nwarps);
             FATTN_SWITCH_CASE(64, 16, nwarps);
             FATTN_SWITCH_CASE(64, 32, nwarps);
-            FATTN_SWITCH_CASE(64, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -655,7 +652,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             // FATTN_SWITCH_CASE(80, 8, nwarps);
             FATTN_SWITCH_CASE(80, 16, nwarps);
             FATTN_SWITCH_CASE(80, 32, nwarps);
-            // FATTN_SWITCH_CASE(80, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -665,7 +661,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(96, 8, nwarps);
             FATTN_SWITCH_CASE(96, 16, nwarps);
             FATTN_SWITCH_CASE(96, 32, nwarps);
-            FATTN_SWITCH_CASE(96, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -675,7 +670,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             // FATTN_SWITCH_CASE(112, 8, nwarps);
             FATTN_SWITCH_CASE(112, 16, nwarps);
             FATTN_SWITCH_CASE(112, 32, nwarps);
-            // FATTN_SWITCH_CASE(112, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -685,7 +679,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(128, 8, nwarps);
             FATTN_SWITCH_CASE(128, 16, nwarps);
             FATTN_SWITCH_CASE(128, 32, nwarps);
-            // FATTN_SWITCH_CASE(128, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
@@ -695,7 +688,6 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
             FATTN_SWITCH_CASE(256, 8, nwarps);
             FATTN_SWITCH_CASE(256, 16, nwarps);
             FATTN_SWITCH_CASE(256, 32, nwarps);
-            // FATTN_SWITCH_CASE(256, 64, nwarps);
             default:
                 fprintf(stderr, "cols_per_block == %d not implemented.\n", cols_per_block);
                 GGML_ASSERT(false);
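
For reference, below is a minimal standalone sketch of the cols_per_block heuristic as it reads after this patch: the ncols == 64 branch is gone, so large query batches now fall into the 32-column kernels. This is not part of the patch; the function name pick_cols_per_block, the parameters head_size / n_cols / cc, and the CC_AMPERE_SKETCH constant are illustrative stand-ins for Q->ne[0], Q->ne[1], the device compute capability, and CC_AMPERE inside ggml_cuda_flash_attn_ext(). The final fallback value of 8 is assumed from the remaining FATTN_SWITCH_CASE(..., 8, nwarps) entries; the corresponding else branch lies outside the hunks shown above.

// Standalone sketch only, not ggml code.
#include <cstdint>
#include <cstdio>

static const int CC_AMPERE_SKETCH = 800; // assumed stand-in for CC_AMPERE

static int pick_cols_per_block(int64_t head_size, int64_t n_cols, int cc) {
    // The ncols == 64 branch that previously came first has been removed.
    if (n_cols >= 64 && (head_size <= 128 || cc >= CC_AMPERE_SKETCH)) {
        return 32;
    }
    if (n_cols >= 32 || head_size % 32 != 0) {
        return 16;
    }
    // Assumed fallback, see note above.
    return 8;
}

int main() {
    // Head size 128, 512 query columns, Ampere-class device -> 32 columns per block.
    printf("%d\n", pick_cols_per_block(128, 512, 860));
    // Head size 80, 16 query columns -> 16, because 80 % 32 != 0.
    printf("%d\n", pick_cols_per_block(80, 16, 700));
    return 0;
}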