From ed242259176cf7fa7471b0ac8349e7677925ef13 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 18 Apr 2023 23:33:03 +0300 Subject: [PATCH] ggml : optimize ggml_vec_dot_q4_1_q8_0() via vmalq_n_f32 56 ms/token with Q4_1 ! --- ggml.c | 44 +++++++++++++++++++++----------------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/ggml.c b/ggml.c index 622b029f6..2a54cd6f7 100644 --- a/ggml.c +++ b/ggml.c @@ -2446,10 +2446,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * // TODO: add AVX / WASM SIMD / etc #if defined(__ARM_NEON) - float sum00 = 0.0f; - float sum01 = 0.0f; - float sum10 = 0.0f; - float sum11 = 0.0f; + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); for (int i = 0; i < nb; i += 2) { const block_q4_1 * restrict x0 = &x[i + 0]; @@ -2480,20 +2478,24 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const int8x16_t v1_1ls = vuzp1q_s8(v1_1l, v1_1h); const int8x16_t v1_1hs = vuzp2q_s8(v1_1l, v1_1h); - // Note: cannot use vaddvq_s8 because it overflows for 8-bit values - // TODO: is there a better way to do this? - sum00 += (x0->m*y0->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_0ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0ls))) + - vaddvq_s16(vmovl_s8(vget_low_s8(v1_0hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_0hs)))); - sum01 += (x1->m*y1->d)*(vaddvq_s16(vmovl_s8(vget_low_s8(v1_1ls))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1ls))) + - vaddvq_s16(vmovl_s8(vget_low_s8(v1_1hs))) + vaddvq_s16(vmovl_s8(vget_high_s8(v1_1hs)))); + const int16x8_t s0i = vaddq_s16( + vaddq_s16(vmovl_s8(vget_low_s8(v1_0ls)), vmovl_s8(vget_high_s8(v1_0ls))), + vaddq_s16(vmovl_s8(vget_low_s8(v1_0hs)), vmovl_s8(vget_high_s8(v1_0hs)))); + + const int16x8_t s1i = vaddq_s16( + vaddq_s16(vmovl_s8(vget_low_s8(v1_1ls)), vmovl_s8(vget_high_s8(v1_1ls))), + vaddq_s16(vmovl_s8(vget_low_s8(v1_1hs)), vmovl_s8(vget_high_s8(v1_1hs)))); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s0i), vget_high_s16(s0i))), x0->m*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(s1i), vget_high_s16(s1i))), x1->m*y1->d); #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs); const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs); - sum10 += (x0->d*y0->d)*vaddvq_s32(p_0); - sum11 += (x1->d*y1->d)*vaddvq_s32(p_1); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0l), vget_low_s8 (v1_0ls)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0l), vget_high_s8(v1_0ls)); @@ -2505,21 +2507,17 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1h), vget_low_s8 (v1_1hs)); const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1h), vget_high_s8(v1_1hs)); - const int16x8_t pl_0 = vaddq_s16(pl0l, pl0h); - const int16x8_t ph_0 = vaddq_s16(ph0l, ph0h); + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - const int16x8_t pl_1 = vaddq_s16(pl1l, pl1h); - const int16x8_t ph_1 = vaddq_s16(ph1l, ph1h); - - const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); - const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); - - sum10 += x0->d*y0->d*vaddvq_s16(p_0); - sum11 += x1->d*y1->d*vaddvq_s16(p_1); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0->d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), x1->d*y1->d); #endif } - sumf = sum00 + sum01 + sum10 + sum11; + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); #else // scalar for (int i = 0; i < nb; i++) {