From d436f5ba2c59bc8033e2eefe6d2a1d862abb2506 Mon Sep 17 00:00:00 2001
From: Shupei Fan
Date: Tue, 24 Sep 2024 12:33:32 +0800
Subject: [PATCH] [metal] (HACK!!!) force use kernel_flash_attn_ext_scalar_f16 in FA

---
 ggml/src/ggml-metal.m | 91 +++++++++++++++++++++++++++++++------------
 1 file changed, 66 insertions(+), 25 deletions(-)

diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 8ff16983e..1b51a2518 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -206,6 +206,14 @@ enum ggml_metal_kernel_type {
   //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
   //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@@ -702,6 +710,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,         flash_attn_ext_f16_h256,         ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,     flash_attn_ext_vec_f16_h128,     ctx->support_simdgroup_reduction);
       //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,     flash_attn_ext_vec_f16_h256,     ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32,   flash_attn_ext_scalar_f16_h32,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64,   flash_attn_ext_scalar_f16_h64,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96,   flash_attn_ext_scalar_f16_h96,   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128,  flash_attn_ext_scalar_f16_h128,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32,  flash_attn_ext_scalar_q8_0_h32,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64,  flash_attn_ext_scalar_q8_0_h64,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96,  flash_attn_ext_scalar_q8_0_h96,  true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128, flash_attn_ext_scalar_q8_0_h128, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                     cpy_f32_f16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                     cpy_f32_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16,                     cpy_f16_f16,                     true);
@@ -852,15 +868,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_LEAKY_RELU:
             return true;
        case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[1]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[2]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 256) {
-                return false;
-            }
+            // if (op->src[1]->type != GGML_TYPE_F16) {
+            //     return false;
+            // }
+            // if (op->src[2]->type != GGML_TYPE_F16) {
+            //     return false;
+            // }
+            // if (op->src[0]->ne[0] == 256) {
+            //     return false;
+            // }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
@@ -2765,6 +2781,8 @@ static void ggml_metal_encode_node(
 
             GGML_ASSERT(ne11 % 32 == 0);
             GGML_ASSERT(src0->type == GGML_TYPE_F32);
+            // K and V must have the same type
+            GGML_ASSERT(src1->type == src2->type);
 
             GGML_ASSERT(ggml_are_same_shape (src1, src2));
 
@@ -2811,33 +2829,56 @@ static void ggml_metal_encode_node(
 
             bool use_vec_kernel = false;
 
-            if (ne01 >= 4 || (ne00%128 != 0)) {
+            if (false) { // HACK: never take the simdgroup-mm path; always dispatch the scalar kernels below
                 switch (ne00) {
                     case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
                     case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
                     case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                     case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                     case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
-                  //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                    //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                     default:
-                              {
-                                  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                  GGML_METAL_LOG_ERROR("add template specialization for this size\n");
-                                  GGML_ABORT("add template specialization for this size");
-                              }
+                    {
+                        GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                        GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                        GGML_ABORT("add template specialization for this size");
+                    }
                 }
             } else {
                 use_vec_kernel = true;
 
-                switch (ne00) {
-                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                  //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
-                    default:
-                              {
-                                  GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
-                                  GGML_METAL_LOG_ERROR("add template specialization for this size\n");
-                                  GGML_ABORT("add template specialization for this size");
-                              }
+                if (src1->type == GGML_TYPE_F16) {
+                    switch (ne00) {
+                        case 32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32 ].pipeline; break;
+                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64 ].pipeline; break;
+                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96 ].pipeline; break;
+                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128].pipeline; break;
+                        //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        default:
+                        {
+                            GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                            GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                            GGML_ABORT("add template specialization for this size");
+                        }
+                    }
+                } else if (src1->type == GGML_TYPE_Q8_0) {
+                    switch (ne00) {
+                        case 32:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32 ].pipeline; break;
+                        case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64 ].pipeline; break;
+                        case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96 ].pipeline; break;
+                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128].pipeline; break;
+                        //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        default:
+                        {
+                            GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                            GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                            GGML_ABORT("add template specialization for this size");
+                        }
+                    }
+                } else {
+                    // this branch means an unsupported K/V type, not an unsupported head size
+                    GGML_METAL_LOG_ERROR("unsupported K/V type: %s\n", ggml_type_name(src1->type));
+                    GGML_METAL_LOG_ERROR("add a scalar FA kernel for this type\n");
+                    GGML_ABORT("add a scalar FA kernel for this type");
                 }
             }
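
Note: this patch only wires up the host side. GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64, flash_attn_ext_scalar_f16_h64, true) looks up a compute function named "kernel_flash_attn_ext_scalar_f16_h64" in the compiled Metal library, so matching entry points must exist in ggml/src/ggml-metal.metal. A minimal sketch of what those instantiations might look like, modeled on the existing flash_attn_ext_vec_f16 instantiations in that file — the template name kernel_flash_attn_ext_scalar, its parameter list, and the typedef are assumptions (they come from the author's companion shader change, which is not shown here); only the [[host_name]] strings are fixed by the host-side lookups above:

// ggml-metal.metal (sketch; kernel_flash_attn_ext_scalar<T, H> is a
// hypothetical templated kernel over K/V storage type T and head size H --
// only the host_name strings below are dictated by this patch)
typedef decltype(kernel_flash_attn_ext_scalar<half, 64>) flash_attn_ext_scalar_t;

template [[host_name("kernel_flash_attn_ext_scalar_f16_h32")]]   kernel flash_attn_ext_scalar_t kernel_flash_attn_ext_scalar<half,        32>;
template [[host_name("kernel_flash_attn_ext_scalar_f16_h64")]]   kernel flash_attn_ext_scalar_t kernel_flash_attn_ext_scalar<half,        64>;
template [[host_name("kernel_flash_attn_ext_scalar_f16_h96")]]   kernel flash_attn_ext_scalar_t kernel_flash_attn_ext_scalar<half,        96>;
template [[host_name("kernel_flash_attn_ext_scalar_f16_h128")]]  kernel flash_attn_ext_scalar_t kernel_flash_attn_ext_scalar<half,       128>;
template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h32")]]  kernel flash_attn_ext_scalar_t kernel_flash_attn_ext_scalar<block_q8_0,  32>;
template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h64")]]  kernel flash_attn_ext_scalar_t kernel_flash_attn_ext_scalar<block_q8_0,  64>;
template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h96")]]  kernel flash_attn_ext_scalar_t kernel_flash_attn_ext_scalar<block_q8_0,  96>;
template [[host_name("kernel_flash_attn_ext_scalar_q8_0_h128")]] kernel flash_attn_ext_scalar_t kernel_flash_attn_ext_scalar<block_q8_0, 128>;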
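To sanity-check the forced path, running tests/test-backend-ops with the FLASH_ATTN_EXT filter (test mode for correctness against the CPU backend, perf mode for timings), e.g. `./tests/test-backend-ops test -o FLASH_ATTN_EXT`, should now hit the scalar kernels for every supported head size. In end-to-end runs the Q8_0 variants are only reached when the KV cache is actually quantized, e.g. with `-fa -ctk q8_0 -ctv q8_0` on llama-cli. These invocations are from memory of the llama.cpp tree of this era; double-check the flag spellings against the local `--help` output.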