retain the tensor type as Q4_0

This commit is contained in:
Charles Xu 2024-11-07 11:06:08 +01:00
parent b632bf0fc5
commit 5947d72c84
4 changed files with 28 additions and 9 deletions

View file

@ -3540,17 +3540,14 @@ int ggml_prepare_optimal_kernel(struct ggml_tensor * cur, const void * data, siz
#if defined(__ARM_ARCH)
if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
repack_q4_0_to_q4_0_8_bl(cur, 8, data, data_size);
cur->type = GGML_TYPE_Q4_0_8_8;
ret = 0;
}
else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
repack_q4_0_to_q4_0_4_bl(cur, 8, data, data_size);
cur->type = GGML_TYPE_Q4_0_4_8;
ret = 0;
}
else if (ggml_cpu_has_neon()) {
repack_q4_0_to_q4_0_4_bl(cur, 4, data, data_size);
cur->type = GGML_TYPE_Q4_0_4_4;
ret = 0;
}
#endif
@ -3560,4 +3557,23 @@ int ggml_prepare_optimal_kernel(struct ggml_tensor * cur, const void * data, siz
GGML_UNUSED(data);
GGML_UNUSED(data_size);
}
// Returns the repacked tensor type that the AArch64 CPU backend would select
// for `cur` (mirroring the dispatch order in ggml_prepare_optimal_kernel):
// Q4_0 is mapped to Q4_0_8_8 (SVE + i8mm, 256-bit vectors), Q4_0_4_8
// (NEON + i8mm), or Q4_0_4_4 (NEON only), checked in that priority order.
// For non-Q4_0 tensors, or on non-ARM builds, the tensor's own type is
// returned unchanged. Does not modify `cur`.
enum ggml_type ggml_get_optimal_type(const struct ggml_tensor * cur) {
#if defined(__ARM_ARCH)
    if (cur->type == GGML_TYPE_Q4_0) {
        if (ggml_cpu_has_sve() && ggml_cpu_has_matmul_int8() && ggml_cpu_get_sve_cnt() == QK8_0) {
            return GGML_TYPE_Q4_0_8_8;
        }
        else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
            return GGML_TYPE_Q4_0_4_8;
        }
        else if (ggml_cpu_has_neon()) {
            return GGML_TYPE_Q4_0_4_4;
        }
    }
#endif
    // `cur` is always read here, so no GGML_UNUSED(cur) marker is needed;
    // the original placed one after this return, where it was unreachable.
    return cur->type;
}
#endif

View file

@ -35,6 +35,7 @@ void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#ifdef GGML_USE_CPU_AARCH64
int ggml_prepare_optimal_kernel(struct ggml_tensor * cur, const void * data, size_t data_size);
enum ggml_type ggml_get_optimal_type(const struct ggml_tensor * cur);
#endif
#ifdef __cplusplus

View file

@ -2638,11 +2638,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
#ifdef GGML_USE_CPU_AARCH64
const struct ggml_tensor *tensor = op->src[0];
if (tensor && tensor->buffer && (strcmp(tensor->buffer->buft->iface.get_name(tensor->buffer->buft),"CPU_AARCH64") == 0)) {
if ((op->op == GGML_OP_MUL_MAT) &&
(tensor->type == GGML_TYPE_Q4_0 ||
tensor->type == GGML_TYPE_Q4_0_4_4 ||
tensor->type == GGML_TYPE_Q4_0_4_8 ||
tensor->type == GGML_TYPE_Q4_0_8_8)) {
if (op->op == GGML_OP_MUL_MAT && tensor->type == GGML_TYPE_Q4_0) {
return op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == ggml_get_type_traits_cpu(tensor->type)->vec_dot_type;
}
return false;

View file

@ -7425,7 +7425,13 @@ static void ggml_compute_forward_mul_mat(
const int ith = params->ith;
const int nth = params->nth;
const enum ggml_type type = src0->type;
enum ggml_type type = src0->type;
#ifdef GGML_USE_CPU_AARCH64
if (strcmp(src0->buffer->buft->iface.get_name(src0->buffer->buft),"CPU_AARCH64") == 0) {
type = ggml_get_optimal_type(src0);
}
#endif
enum ggml_type const vec_dot_type = type_traits_cpu[type].vec_dot_type;
ggml_from_float_t const from_float = ggml_get_type_traits(vec_dot_type)->from_float;