[metal] (HACK!!!) force use kernel_flash_attn_ext_scalar_f16 in FA
Parent: 9e62e7e10e
Commit: d436f5ba2c
1 changed file with 66 additions and 25 deletions
@@ -206,6 +206,14 @@ enum ggml_metal_kernel_type {
     //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
     //GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
@@ -702,6 +710,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32, flash_attn_ext_scalar_f16_h32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64, flash_attn_ext_scalar_f16_h64, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96, flash_attn_ext_scalar_f16_h96, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128, flash_attn_ext_scalar_f16_h128, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32, flash_attn_ext_scalar_q8_0_h32, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64, flash_attn_ext_scalar_q8_0_h64, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96, flash_attn_ext_scalar_q8_0_h96, true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128, flash_attn_ext_scalar_q8_0_h128, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
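
Note: each GGML_METAL_ADD_KERNEL line above registers one compute pipeline, and the third argument gates registration, which is why the new scalar kernels pass true while the simdgroup variants pass capability flags. A simplified sketch of what the macro does (not the verbatim macro from ggml-metal.m; metal_library and error are assumed from the surrounding init code):

    #define GGML_METAL_ADD_KERNEL(e, name, supported) \
        if (supported) { \
            /* look up "kernel_<name>" in the compiled Metal library */ \
            id<MTLFunction> fn = [metal_library newFunctionWithName:@"kernel_" #name]; \
            /* build the compute pipeline and store it in the kernel table */ \
            ctx->kernels[e].pipeline = \
                [ctx->device newComputePipelineStateWithFunction:fn error:&error]; \
        }
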
@@ -852,15 +868,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
         case GGML_OP_LEAKY_RELU:
             return true;
         case GGML_OP_FLASH_ATTN_EXT:
-            if (op->src[1]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[2]->type != GGML_TYPE_F16) {
-                return false;
-            }
-            if (op->src[0]->ne[0] == 256) {
-                return false;
-            }
+            // if (op->src[1]->type != GGML_TYPE_F16) {
+            //     return false;
+            // }
+            // if (op->src[2]->type != GGML_TYPE_F16) {
+            //     return false;
+            // }
+            // if (op->src[0]->ne[0] == 256) {
+            //     return false;
+            // }
             return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
         case GGML_OP_SSM_CONV:
         case GGML_OP_SSM_SCAN:
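
Net effect of the hunk above: ggml_metal_supports_op no longer rejects non-F16 K/V or a head size of 256 for GGML_OP_FLASH_ATTN_EXT, so Q8_0 KV data can reach the new scalar kernels. A minimal before/after sketch of the predicate, assuming ggml.h for the tensor types (illustration only, not the actual function):

    #include "ggml.h"
    #include <stdbool.h>

    // Before this commit: K and V had to be F16, and head size 256 was rejected.
    static bool supports_fa_before(const struct ggml_tensor * op, bool simdgroup_mm) {
        if (op->src[1]->type != GGML_TYPE_F16) return false; // K
        if (op->src[2]->type != GGML_TYPE_F16) return false; // V
        if (op->src[0]->ne[0] == 256)          return false; // head size
        return simdgroup_mm;
    }

    // After: only the capability gate remains (the guards are commented out).
    static bool supports_fa_after(const struct ggml_tensor * op, bool simdgroup_mm) {
        (void) op;
        return simdgroup_mm; // TODO: over-restricted for vec-kernels
    }
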
@@ -2765,6 +2781,8 @@ static void ggml_metal_encode_node(
             GGML_ASSERT(ne11 % 32 == 0);

             GGML_ASSERT(src0->type == GGML_TYPE_F32);
+            // K, V shall have the same type
+            GGML_ASSERT(src1->type == src2->type);

             GGML_ASSERT(ggml_are_same_shape (src1, src2));

@@ -2811,33 +2829,56 @@ static void ggml_metal_encode_node(

             bool use_vec_kernel = false;

-            if (ne01 >= 4 || (ne00%128 != 0)) {
+            if (false) {
                 switch (ne00) {
                     case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
                     case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
                     case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
                     case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
                     case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
                     //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
                     default:
                         {
                             GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
                             GGML_METAL_LOG_ERROR("add template specialization for this size\n");
                             GGML_ABORT("add template specialization for this size");
                         }
                 }
             } else {
                 use_vec_kernel = true;

-                switch (ne00) {
-                    case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
-                    //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
-                    default:
-                        {
-                            GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
-                            GGML_METAL_LOG_ERROR("add template specialization for this size\n");
-                            GGML_ABORT("add template specialization for this size");
-                        }
-                }
+                if (src1->type == GGML_TYPE_F16) {
+                    switch (ne00) {
+                        case 32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32 ].pipeline; break;
+                        case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64 ].pipeline; break;
+                        case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96 ].pipeline; break;
+                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128].pipeline; break;
+                        //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        default:
+                            {
+                                GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                GGML_ABORT("add template specialization for this size");
+                            }
+                    }
+                } else if (src1->type == GGML_TYPE_Q8_0) {
+                    switch (ne00) {
+                        case 32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32 ].pipeline; break;
+                        case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64 ].pipeline; break;
+                        case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96 ].pipeline; break;
+                        case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128].pipeline; break;
+                        //case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
+                        default:
+                            {
+                                GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                                GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                                GGML_ABORT("add template specialization for this size");
+                            }
+                    }
+                } else {
+                    GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                    GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                    GGML_ABORT("add template specialization for this size");
+                }
             }

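
Taken together, the last hunk is the actual hack: if (false) permanently disables the simdgroup-matrix branch, so every flash-attention dispatch falls into the else branch and picks a scalar kernel from the K tensor's type and the head size ne00. A hypothetical helper capturing that selection (illustration only; the kernel-type enum is file-local to ggml-metal.m, and the real code logs an error and aborts instead of returning -1):

    // Maps (K/V type, head size) to one of the eight scalar kernels
    // registered above; -1 marks combinations the hack does not cover.
    static int fa_scalar_kernel(enum ggml_type kv_type, int64_t ne00) {
        const bool f16 = kv_type == GGML_TYPE_F16;
        if (!f16 && kv_type != GGML_TYPE_Q8_0) return -1;
        switch (ne00) {
            case  32: return f16 ? GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32  : GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32;
            case  64: return f16 ? GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64  : GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64;
            case  96: return f16 ? GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96  : GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96;
            case 128: return f16 ? GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128 : GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128;
            default:  return -1; // no template specialization for this head size
        }
    }

Note that use_vec_kernel is still set to true in the taken branch even though scalar kernels are selected, so the downstream threadgroup setup for the vec path is reused for the scalar kernels.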