diff --git a/ggml/src/ggml-aarch64.c b/ggml/src/ggml-aarch64.c
index 2305d08b2..801cf2bdc 100644
--- a/ggml/src/ggml-aarch64.c
+++ b/ggml/src/ggml-aarch64.c
@@ -3540,17 +3540,14 @@ int ggml_prepare_optimal_kernel(struct ggml_tensor * cur, const void * data, siz
 #if defined(__ARM_ARCH)
     if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
         repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
-        cur->type = GGML_TYPE_Q4_0_8_8;
         ret = 0;
     }
     else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
         repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
-        cur->type = GGML_TYPE_Q4_0_4_8;
         ret = 0;
     }
     else if (ggml_cpu_has_neon()) {
         repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
-        cur->type = GGML_TYPE_Q4_0_4_4;
         ret = 0;
     }
 #endif
@@ -3560,4 +3557,23 @@ int ggml_prepare_optimal_kernel(struct ggml_tensor * cur, const void * data, siz
     GGML_UNUSED(data);
     GGML_UNUSED(data_size);
 }
+
+enum ggml_type ggml_get_optimal_type(const struct ggml_tensor * cur) {
+#if defined(__ARM_ARCH)
+    if (cur->type == GGML_TYPE_Q4_0) {
+        if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
+            return GGML_TYPE_Q4_0_8_8;
+        }
+        else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
+            return GGML_TYPE_Q4_0_4_8;
+        }
+        else if (ggml_cpu_has_neon()) {
+            return GGML_TYPE_Q4_0_4_4;
+        }
+    }
+#endif
+    return cur->type;
+
+    GGML_UNUSED(cur);
+}
 #endif
diff --git a/ggml/src/ggml-aarch64.h b/ggml/src/ggml-aarch64.h
index 61860fcfb..0353c6be4 100644
--- a/ggml/src/ggml-aarch64.h
+++ b/ggml/src/ggml-aarch64.h
@@ -35,6 +35,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
 
 #ifdef GGML_USE_CPU_AARCH64
 int ggml_prepare_optimal_kernel(struct ggml_tensor * cur, const void * data, size_t data_size);
+enum ggml_type ggml_get_optimal_type(const struct ggml_tensor * cur);
 #endif
 
 #ifdef __cplusplus
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 959860424..b21a92a76 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -2638,11 +2638,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
 #ifdef GGML_USE_CPU_AARCH64
     const struct ggml_tensor *tensor = op->src[0];
     if (tensor && tensor->buffer && (strcmp(tensor->buffer->buft->iface.get_name(tensor->buffer->buft),"CPU_AARCH64") == 0)) {
-        if ((op->op == GGML_OP_MUL_MAT) &&
-            (tensor->type == GGML_TYPE_Q4_0 ||
-             tensor->type == GGML_TYPE_Q4_0_4_4 ||
-             tensor->type == GGML_TYPE_Q4_0_4_8 ||
-             tensor->type == GGML_TYPE_Q4_0_8_8)) {
+        if (op->op == GGML_OP_MUL_MAT && tensor->type == GGML_TYPE_Q4_0) {
             return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_get_type_traits_cpu(tensor->type)->vec_dot_type;
         }
         return false;
diff --git a/ggml/src/ggml-cpu.c b/ggml/src/ggml-cpu.c
index de1de18ec..b62fd3413 100644
--- a/ggml/src/ggml-cpu.c
+++ b/ggml/src/ggml-cpu.c
@@ -7425,7 +7425,13 @@ static void ggml_compute_forward_mul_mat(
     const int ith = params->ith;
     const int nth = params->nth;
 
-    const enum ggml_type type = src0->type;
+    enum ggml_type type = src0->type;
+
+#ifdef GGML_USE_CPU_AARCH64
+    if (strcmp(src0->buffer->buft->iface.get_name(src0->buffer->buft),"CPU_AARCH64") == 0) {
+        type = ggml_get_optimal_type(src0);
+    }
+#endif
 
     enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
     ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float;