From 7b72318e6f1733b03c82f7a7c33227b6b40852db Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Wed, 3 Jan 2024 15:35:24 +0100 Subject: [PATCH] iq2_xxs: ARM_NEON dot product Somehow strangely slow (112 ms/token). --- ggml-quants.c | 43 ++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 42 insertions(+), 1 deletion(-) diff --git a/ggml-quants.c b/ggml-quants.c index 6e66ed841..24e39c73f 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -7243,7 +7243,48 @@ void ggml_vec_dot_iq2_xxs_q8_K(const int n, float * restrict s, const void * res const int nb = n / QK_K; -#if defined __AVX2__ +#if defined(__ARM_NEON) + + const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs; + + uint32_t aux32[4]; + const uint8_t * aux8 = (const uint8_t *)aux32; + + int8x16x4_t q2u; + int8x16x4_t q2s; + int8x16x4_t q8b; + + float sumf = 0; + for (int i = 0; i < nb; ++i) { + const float d = GGML_FP16_TO_FP32(x[i].d) * y[i].d; + const uint16_t * restrict q2 = x[i].qs; + const int8_t * restrict q8 = y[i].qs; + float sumf1 = 0, sumf2 = 0; + for (int ib32 = 0; ib32 < QK_K/32; ib32 += 2) { + q8b = vld1q_s8_x4(q8); q8 += 64; + memcpy(aux32, q2, 4*sizeof(uint32_t)); q2 += 8; + q2u.val[0] = vcombine_u8(vld1_u8((const void *)(iq2xxs_grid + aux8[ 0])), vld1_u8((const void *)(iq2xxs_grid + aux8[ 1]))); + q2u.val[1] = vcombine_u8(vld1_u8((const void *)(iq2xxs_grid + aux8[ 2])), vld1_u8((const void *)(iq2xxs_grid + aux8[ 3]))); + q2u.val[2] = vcombine_u8(vld1_u8((const void *)(iq2xxs_grid + aux8[ 8])), vld1_u8((const void *)(iq2xxs_grid + aux8[ 9]))); + q2u.val[3] = vcombine_u8(vld1_u8((const void *)(iq2xxs_grid + aux8[10])), vld1_u8((const void *)(iq2xxs_grid + aux8[11]))); + q2s.val[0] = vcombine_u8(vld1_u8((const void *)(signs64 + ((aux32[1] >> 0) & 127))), vld1_u8((const void *)(signs64 + ((aux32[1] >> 7) & 127)))); + q2s.val[1] = vcombine_u8(vld1_u8((const void *)(signs64 + ((aux32[1] >> 14) & 127))), vld1_u8((const void *)(signs64 + ((aux32[1] >> 21) & 127)))); + q2s.val[2] = vcombine_u8(vld1_u8((const void *)(signs64 + ((aux32[3] >> 0) & 127))), vld1_u8((const void *)(signs64 + ((aux32[3] >> 7) & 127)))); + q2s.val[3] = vcombine_u8(vld1_u8((const void *)(signs64 + ((aux32[3] >> 14) & 127))), vld1_u8((const void *)(signs64 + ((aux32[3] >> 21) & 127)))); + q2u.val[0] = vmulq_s8(q2u.val[0], q2s.val[0]); + q2u.val[1] = vmulq_s8(q2u.val[1], q2s.val[1]); + q2u.val[2] = vmulq_s8(q2u.val[2], q2s.val[2]); + q2u.val[3] = vmulq_s8(q2u.val[3], q2s.val[3]); + const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[0], q8b.val[0]), q2u.val[1], q8b.val[1]); + const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q2u.val[2], q8b.val[2]), q2u.val[3], q8b.val[3]); + sumf1 += vaddvq_s32(p1) * (0.5f + (aux32[1] >> 28)); + sumf2 += vaddvq_s32(p2) * (0.5f + (aux32[3] >> 28)); + } + sumf += d*(sumf1 + sumf2); + } + *s = 0.25f * sumf; + +#elif defined(__AVX2__) const uint64_t * signs64 = (const uint64_t *)keven_signs_q2xs;