From d92c5a9e2978bcd18cde42875f92abe8b94cfd7b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 23 Jun 2023 08:43:04 +0300 Subject: [PATCH] k_quants: WIP super-blocks with 64 weights Another small improvement for Q3_K and Q5_K on ARM_NEON --- k_quants.c | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/k_quants.c b/k_quants.c index 3c3928590..b43286f4c 100644 --- a/k_quants.c +++ b/k_quants.c @@ -1903,10 +1903,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri const float d = y[i].d * (float)x[i].d; - q3h.val[0] = vandq_u8(mh, vcombine_u8(vshl_n_u8(hbits, 2), vshl_n_u8(hbits, 1))); - q3h.val[1] = vandq_u8(mh, vcombine_u8(hbits, vshr_n_u8(hbits, 1))); - q3h.val[2] = vandq_u8(mh, vcombine_u8(vshr_n_u8(hbits, 2), vshr_n_u8(hbits, 3))); - q3h.val[3] = vandq_u8(mh, vcombine_u8(vshr_n_u8(hbits, 4), vshr_n_u8(hbits, 5))); + const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1)); + q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); + q3h.val[1] = vandq_u8(mh, htmp); + q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); + q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4)); q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); @@ -2705,10 +2706,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16x2_t q5bits = vld1q_u8_x2(q5); const int8x16x4_t q8bytes = vld1q_s8_x4(q8); - q5h.val[0] = vandq_u8(mh, vcombine_u8(vshl_n_u8(qhbits, 4), vshl_n_u8(qhbits, 3))); - q5h.val[1] = vandq_u8(mh, vcombine_u8(vshl_n_u8(qhbits, 2), vshl_n_u8(qhbits, 1))); - q5h.val[2] = vandq_u8(mh, vcombine_u8(qhbits, vshr_n_u8(qhbits, 1))); - q5h.val[3] = vandq_u8(mh, vcombine_u8(vshr_n_u8(qhbits, 2), vshr_n_u8(qhbits, 3))); + const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)); + q5h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 4)); + q5h.val[1] = vandq_u8(mh, vshlq_n_u8(htmp, 2)); + q5h.val[2] = vandq_u8(mh, htmp); + q5h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 2)); q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));