From 53e81ca28944f0d86910a30ee959f2edfa24277b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 24 Jun 2023 18:32:38 +0300 Subject: [PATCH] k_quants: 10% faster ARM_NEON Q5_K dot product --- k_quants.c | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/k_quants.c b/k_quants.c index d913b8760..46dd884b0 100644 --- a/k_quants.c +++ b/k_quants.c @@ -2769,8 +2769,6 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const float d = y[i].d * (float)x[i].d; const int8_t * sc = x[i].scales; - sumf -= 16.f * d * (sc[0] * y[i].bsums[0] + sc[1] * y[i].bsums[1] + sc[2] * y[i].bsums[2] + sc[3] * y[i].bsums[3]); - const uint8_t * restrict q5 = x[i].qs; const uint8_t * restrict qh = x[i].qh; const int8_t * restrict q8 = y[i].qs; @@ -2781,15 +2779,15 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const int8x16x4_t q8bytes = vld1q_s8_x4(q8); const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); - q5h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 4)); - q5h.val[1] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); - q5h.val[2] = vandq_u8(mh, htmp); - q5h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); + q5h.val[0] = vbicq_u8(mh, vshlq_n_u8(htmp, 4)); + q5h.val[1] = vbicq_u8(mh, vshlq_n_u8(htmp, 2)); + q5h.val[2] = vbicq_u8(mh, htmp); + q5h.val[3] = vbicq_u8(mh, vshrq_n_u8(htmp, 2)); - q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); - q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); - q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2])); - q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3])); + q5bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[0], m4b)), vreinterpretq_s8_u8(q5h.val[0])); + q5bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q5bits.val[1], m4b)), vreinterpretq_s8_u8(q5h.val[1])); + q5bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[0], 4)), vreinterpretq_s8_u8(q5h.val[2])); + q5bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(q5bits.val[1], 4)), vreinterpretq_s8_u8(q5h.val[3])); #if defined(__ARM_FEATURE_DOTPROD)