[metal] (HACK!!!) force use kernel_flash_attn_ext_scalar_f16 in FA
This commit is contained in:
parent
9e62e7e10e
commit
d436f5ba2c
1 changed files with 66 additions and 25 deletions
|
@ -206,6 +206,14 @@ enum ggml_metal_kernel_type {
|
||||||
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
||||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||||
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
//GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, // https://github.com/ggerganov/llama.cpp/issues/7261
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96,
|
||||||
|
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
GGML_METAL_KERNEL_TYPE_CPY_F16_F16,
|
||||||
|
@ -702,6 +710,14 @@ static struct ggml_backend_metal_context * ggml_metal_init(void) {
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128, flash_attn_ext_vec_f16_h128, ctx->support_simdgroup_reduction);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256, flash_attn_ext_vec_f16_h256, ctx->support_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32, flash_attn_ext_scalar_f16_h32, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64, flash_attn_ext_scalar_f16_h64, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96, flash_attn_ext_scalar_f16_h96, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128,flash_attn_ext_scalar_f16_h128, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32, flash_attn_ext_scalar_q8_0_h32, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64, flash_attn_ext_scalar_q8_0_h64, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96, flash_attn_ext_scalar_q8_0_h96, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128,flash_attn_ext_scalar_q8_0_h128, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||||
|
@ -852,15 +868,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_context * ctx
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
return true;
|
return true;
|
||||||
case GGML_OP_FLASH_ATTN_EXT:
|
case GGML_OP_FLASH_ATTN_EXT:
|
||||||
if (op->src[1]->type != GGML_TYPE_F16) {
|
// if (op->src[1]->type != GGML_TYPE_F16) {
|
||||||
return false;
|
// return false;
|
||||||
}
|
// }
|
||||||
if (op->src[2]->type != GGML_TYPE_F16) {
|
// if (op->src[2]->type != GGML_TYPE_F16) {
|
||||||
return false;
|
// return false;
|
||||||
}
|
// }
|
||||||
if (op->src[0]->ne[0] == 256) {
|
// if (op->src[0]->ne[0] == 256) {
|
||||||
return false;
|
// return false;
|
||||||
}
|
// }
|
||||||
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||||
case GGML_OP_SSM_CONV:
|
case GGML_OP_SSM_CONV:
|
||||||
case GGML_OP_SSM_SCAN:
|
case GGML_OP_SSM_SCAN:
|
||||||
|
@ -2765,6 +2781,8 @@ static void ggml_metal_encode_node(
|
||||||
GGML_ASSERT(ne11 % 32 == 0);
|
GGML_ASSERT(ne11 % 32 == 0);
|
||||||
|
|
||||||
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
GGML_ASSERT(src0->type == GGML_TYPE_F32);
|
||||||
|
// K, V shall have the same type
|
||||||
|
GGML_ASSERT(src1->type == src2->type);
|
||||||
|
|
||||||
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
GGML_ASSERT(ggml_are_same_shape (src1, src2));
|
||||||
|
|
||||||
|
@ -2811,33 +2829,56 @@ static void ggml_metal_encode_node(
|
||||||
|
|
||||||
bool use_vec_kernel = false;
|
bool use_vec_kernel = false;
|
||||||
|
|
||||||
if (ne01 >= 4 || (ne00%128 != 0)) {
|
if (false) {
|
||||||
switch (ne00) {
|
switch (ne00) {
|
||||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||||
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
|
||||||
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
|
||||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
|
||||||
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||||
GGML_ABORT("add template specialization for this size");
|
GGML_ABORT("add template specialization for this size");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
use_vec_kernel = true;
|
use_vec_kernel = true;
|
||||||
|
|
||||||
switch (ne00) {
|
if (src1->type == GGML_TYPE_F16) {
|
||||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
switch (ne00) {
|
||||||
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
case 32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H32].pipeline; break;
|
||||||
default:
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H64].pipeline; break;
|
||||||
{
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H96].pipeline; break;
|
||||||
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_F16_H128].pipeline; break;
|
||||||
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
||||||
GGML_ABORT("add template specialization for this size");
|
default:
|
||||||
}
|
{
|
||||||
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ABORT("add template specialization for this size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (src1->type == GGML_TYPE_Q8_0) {
|
||||||
|
switch (ne00) {
|
||||||
|
case 32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H32].pipeline; break;
|
||||||
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H64].pipeline; break;
|
||||||
|
case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H96].pipeline; break;
|
||||||
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_SCALAR_Q8_0_H128].pipeline; break;
|
||||||
|
//case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
||||||
|
default:
|
||||||
|
{
|
||||||
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ABORT("add template specialization for this size");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
|
||||||
|
GGML_METAL_LOG_ERROR("add template specialization for this size\n");
|
||||||
|
GGML_ABORT("add template specialization for this size");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue