From 3a7908940f6defe0d5fbc5b7f985699bd87b29b0 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 18 Apr 2023 22:39:35 +0300
Subject: [PATCH] ggml : speed-up q4_2

- 4 threads: ~100ms -> ~90ms
- 8 threads:  ~55ms -> ~50ms
---
 ggml.c | 35 ++++++++++++++++++++++++-----------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/ggml.c b/ggml.c
index 24cfd0009..854e251a6 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3057,8 +3057,8 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
     float sumf = 0.0;
 
 #if defined(__ARM_NEON)
-    float sum0 = 0.0f;
-    float sum1 = 0.0f;
+    float32x4_t sumv0 = vdupq_n_f32(0.0f);
+    float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
     for (int i = 0; i < nb; i += 2) {
         const block_q4_2 * restrict x0_0 = &x[2*(i + 0) + 0];
@@ -3099,10 +3099,21 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
         const int8x16_t v1_1h = vld1q_s8(y1->qs + 16);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-        sum0 += (GGML_FP16_TO_FP32(x0_0->d)*y0->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l));
-        sum0 += (GGML_FP16_TO_FP32(x0_1->d)*y0->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h));
-        sum1 += (GGML_FP16_TO_FP32(x1_0->d)*y1->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l));
-        sum1 += (GGML_FP16_TO_FP32(x1_1->d)*y1->d)*vaddvq_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h));
+        const float32x4_t x0_0d = vdupq_n_f32(GGML_FP16_TO_FP32(x0_0->d));
+        const float32x4_t x0_1d = vdupq_n_f32(GGML_FP16_TO_FP32(x0_1->d));
+        const float32x4_t x1_0d = vdupq_n_f32(GGML_FP16_TO_FP32(x1_0->d));
+        const float32x4_t x1_1d = vdupq_n_f32(GGML_FP16_TO_FP32(x1_1->d));
+
+        const float32x4_t y0d = vdupq_n_f32(y0->d);
+        const float32x4_t y1d = vdupq_n_f32(y1->d);
+
+        sumv0 = vaddq_f32(sumv0, vmulq_f32(y0d, vaddq_f32(
+                        vmulq_f32(x0_0d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l))),
+                        vmulq_f32(x0_1d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h))))));
+
+        sumv1 = vaddq_f32(sumv1, vmulq_f32(y1d, vaddq_f32(
+                        vmulq_f32(x1_0d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l))),
+                        vmulq_f32(x1_1d, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_1hz, v1_1h))))));
 #else
         const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
         const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
@@ -3119,14 +3130,16 @@ static void ggml_vec_dot_q4_2_q8_0(const int n, float * restrict s, const void *
         const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h));
         const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h));
 
-        sum0 += (GGML_FP16_TO_FP32(x0_0->d)*y0->d)*vaddvq_s32(pl0);
-        sum0 += (GGML_FP16_TO_FP32(x0_1->d)*y0->d)*vaddvq_s32(ph0);
-        sum1 += (GGML_FP16_TO_FP32(x1_0->d)*y1->d)*vaddvq_s32(pl1);
-        sum1 += (GGML_FP16_TO_FP32(x1_1->d)*y1->d)*vaddvq_s32(ph1);
+        sumv0 = vaddq_f32(sumv0, vmulq_f32(vdupq_n_f32(y0->d), vaddq_f32(
+                        vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x0_0->d)), vcvtq_f32_s32(pl0)),
+                        vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x0_1->d)), vcvtq_f32_s32(ph0)))));
+        sumv1 = vaddq_f32(sumv1, vmulq_f32(vdupq_n_f32(y1->d), vaddq_f32(
+                        vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x1_0->d)), vcvtq_f32_s32(pl1)),
+                        vmulq_f32(vdupq_n_f32(GGML_FP16_TO_FP32(x1_1->d)), vcvtq_f32_s32(ph1)))));
 #endif
     }
 
-    sumf = sum0 + sum1;
+    sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
 #else
     // scalar
     for (int i = 0; i < nb; i++) {