From 2ae3164d29ca8b805d47a943f987001227418196 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 11 Apr 2023 20:41:15 +0300 Subject: [PATCH] ggml : speed-up q4_1 ARM_NEON by ~5% --- ggml.c | 125 ++++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 80 insertions(+), 45 deletions(-) diff --git a/ggml.c b/ggml.c index eb47d8298..bdf3bba81 100644 --- a/ggml.c +++ b/ggml.c @@ -491,6 +491,32 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 ) } #endif +#if __ARM_NEON +#if !defined(__ARM_FEATURE_QRDMX) + +inline static int16_t vaddvq_s16(int16x8_t v) { + const int16x4_t v1 = vadd_s16(vget_low_s16(v), vget_high_s16(v)); + return vaddv_s16(v1); +} + +inline static uint16_t vaddvq_u16(uint16x8_t v) { + const uint16x4_t v1 = vadd_u16(vget_low_u16(v), vget_high_u16(v)); + return vaddv_u16(v1); +} + +inline static int32_t vaddvq_s32(int32x4_t v) { + const int32x2_t v1 = vadd_s32(vget_low_s32(v), vget_high_s32(v)); + return vaddv_s32(v1); +} + +inline static float vaddvq_f32(float32x4_t v) { + const float32x2_t v1 = vadd_f32(vget_low_f32(v), vget_high_f32(v)); + return vaddv_f32(v1); +} + +#endif +#endif + // method 5 // blocks of QK elements // represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors) @@ -1218,15 +1244,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in #define GGML_F32x4_FMA(a, b, c) vfmaq_f32(a, b, c) #define GGML_F32x4_ADD vaddq_f32 #define GGML_F32x4_MUL vmulq_f32 -#if defined(__ARM_FEATURE_QRDMX) - #define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) -#else - #define GGML_F32x4_REDUCE_ONE(x) \ - (vgetq_lane_f32(x, 0) + \ - vgetq_lane_f32(x, 1) + \ - vgetq_lane_f32(x, 2) + \ - vgetq_lane_f32(x, 3)) -#endif +#define GGML_F32x4_REDUCE_ONE(x) vaddvq_f32(x) #define GGML_F32x4_REDUCE(res, x) \ { \ for (int i = 0; i < GGML_F32_ARR/2; ++i) { \ @@ -1849,55 +1867,43 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest // 4-bit -> 8-bit const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8(v0_0, m4b)); const int8x16_t v1_0l = vreinterpretq_s8_u8(vandq_u8(v1_0, m4b)); - const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4)); const int8x16_t v1_0h = vreinterpretq_s8_u8(vshrq_n_u8(v1_0, 4)); const int8x16_t v0_1l = vreinterpretq_s8_u8(vandq_u8(v0_1, m4b)); const int8x16_t v1_1l = vreinterpretq_s8_u8(vandq_u8(v1_1, m4b)); - const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_1, 4)); const int8x16_t v1_1h = vreinterpretq_s8_u8(vshrq_n_u8(v1_1, 4)); // sub 8 const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); const int8x16_t v1_0ls = vsubq_s8(v1_0l, s8b); - const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); const int8x16_t v1_0hs = vsubq_s8(v1_0h, s8b); const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); const int8x16_t v1_1ls = vsubq_s8(v1_1l, s8b); - const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); const int8x16_t v1_1hs = vsubq_s8(v1_1h, s8b); #if defined(__ARM_FEATURE_DOTPROD) - // dot product into int16x8_t + // dot product into int32x4_t int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0ls); int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1ls); p_0 = vdotq_s32(p_0, v0_0hs, v1_0hs); p_1 = vdotq_s32(p_1, v0_1hs, v1_1hs); - // scalar -#if defined(__ARM_FEATURE_QRDMX) - sum0 += x0->d * y0->d * vaddvq_s32(p_0); - sum1 += x1->d * y1->d * vaddvq_s32(p_1); -#else - sum0 += x0->d * y0->d * (vgetq_lane_s32(p_0, 0) + vgetq_lane_s32(p_0, 1) + vgetq_lane_s32(p_0, 2) + vgetq_lane_s32(p_0, 3)); - sum1 += x1->d * y1->d * (vgetq_lane_s32(p_1, 0) + vgetq_lane_s32(p_1, 1) + vgetq_lane_s32(p_1, 2) + vgetq_lane_s32(p_1, 3)); -#endif + sum0 += x0->d*y0->d*vaddvq_s32(p_0); + sum1 += x1->d*y1->d*vaddvq_s32(p_1); #else const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0ls)); const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0ls)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0hs)); const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0hs)); const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1ls)); const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1ls)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1hs)); const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1hs)); @@ -1910,14 +1916,8 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest const int16x8_t p_0 = vaddq_s16(pl_0, ph_0); const int16x8_t p_1 = vaddq_s16(pl_1, ph_1); - // scalar -#if defined(__ARM_FEATURE_QRDMX) - sum0 += x0->d * y0->d * vaddvq_s16(p_0); - sum1 += x1->d * y1->d * vaddvq_s16(p_1); -#else - sum0 += x0->d * y0->d * (vgetq_lane_s16(p_0, 0) + vgetq_lane_s16(p_0, 1) + vgetq_lane_s16(p_0, 2) + vgetq_lane_s16(p_0, 3) + vgetq_lane_s16(p_0, 4) + vgetq_lane_s16(p_0, 5) + vgetq_lane_s16(p_0, 6) + vgetq_lane_s16(p_0, 7)); - sum1 += x1->d * y1->d * (vgetq_lane_s16(p_1, 0) + vgetq_lane_s16(p_1, 1) + vgetq_lane_s16(p_1, 2) + vgetq_lane_s16(p_1, 3) + vgetq_lane_s16(p_1, 4) + vgetq_lane_s16(p_1, 5) + vgetq_lane_s16(p_1, 6) + vgetq_lane_s16(p_1, 7)); -#endif + sum0 += x0->d*y0->d*vaddvq_s16(p_0); + sum1 += x1->d*y1->d*vaddvq_s16(p_1); #endif } @@ -2265,36 +2265,71 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest float sum10 = 0.0f; float sum11 = 0.0f; - for (int i = 0; i < nb; ++i) { + for (int i = 0; i < nb; i += 2) { const block_q4_1 * restrict x0 = &x[i + 0]; const block_q4_1 * restrict y0 = &y[i + 0]; + const block_q4_1 * restrict x1 = &x[i + 1]; + const block_q4_1 * restrict y1 = &y[i + 1]; const uint8x16_t m4b = vdupq_n_u8(0xf); const uint8x16_t v0_0 = vld1q_u8(x0->qs); const uint8x16_t v1_0 = vld1q_u8(y0->qs); + const uint8x16_t v0_1 = vld1q_u8(x1->qs); + const uint8x16_t v1_1 = vld1q_u8(y1->qs); - // and with 0xf + // 4-bit -> 8-bit const uint8x16_t v0_0l = vandq_u8(v0_0, m4b); const uint8x16_t v1_0l = vandq_u8(v1_0, m4b); - const uint8x16_t v0_0h = vshrq_n_u8(v0_0, 4); const uint8x16_t v1_0h = vshrq_n_u8(v1_0, 4); - // dot product into uint16x8_t - const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); - const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); - - const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); - const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); - - const uint16x8_t pl0 = vaddq_u16(pl0l, pl0h); - const uint16x8_t ph0 = vaddq_u16(ph0l, ph0h); + const uint8x16_t v0_1l = vandq_u8(v0_1, m4b); + const uint8x16_t v1_1l = vandq_u8(v1_1, m4b); + const uint8x16_t v0_1h = vshrq_n_u8(v0_1, 4); + const uint8x16_t v1_1h = vshrq_n_u8(v1_1, 4); sum00 += x0->m*y0->m; sum01 += y0->m*x0->d*(vaddvq_u8(v0_0l) + vaddvq_u8(v0_0h)); sum10 += x0->m*y0->d*(vaddvq_u8(v1_0l) + vaddvq_u8(v1_0h)); - sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0)); + + sum00 += x1->m*y1->m; + sum01 += y1->m*x1->d*(vaddvq_u8(v0_1l) + vaddvq_u8(v0_1h)); + sum10 += x1->m*y1->d*(vaddvq_u8(v1_1l) + vaddvq_u8(v1_1h)); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + int32x4_t p_0 = vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0l); + int32x4_t p_1 = vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1l); + + p_0 = vdotq_s32(p_0, v0_0h, v1_0h); + p_1 = vdotq_s32(p_1, v0_1h, v1_1h); + + sum11 += x0->d*y0->d*vaddvq_s32(p_0); + sum11 += x1->d*y1->d*vaddvq_s32(p_1); +#else + const uint16x8_t pl0l = vmull_u8(vget_low_u8 (v0_0l), vget_low_u8 (v1_0l)); + const uint16x8_t pl0h = vmull_u8(vget_high_u8(v0_0l), vget_high_u8(v1_0l)); + const uint16x8_t ph0l = vmull_u8(vget_low_u8 (v0_0h), vget_low_u8 (v1_0h)); + const uint16x8_t ph0h = vmull_u8(vget_high_u8(v0_0h), vget_high_u8(v1_0h)); + + const uint16x8_t pl1l = vmull_u8(vget_low_s8 (v0_1l), vget_low_u8 (v1_1l)); + const uint16x8_t pl1h = vmull_u8(vget_high_s8(v0_1l), vget_high_u8(v1_1l)); + const uint16x8_t ph1l = vmull_u8(vget_low_s8 (v0_1h), vget_low_u8 (v1_1h)); + const uint16x8_t ph1h = vmull_u8(vget_high_s8(v0_1h), vget_high_u8(v1_1h)); + + const uint16x8_t pl_0 = vaddq_u16(pl0l, pl0h); + const uint16x8_t ph_0 = vaddq_u16(ph0l, ph0h); + + const uint16x8_t pl_1 = vaddq_u16(pl1l, pl1h); + const uint16x8_t ph_1 = vaddq_u16(ph1l, ph1h); + + const uint16x8_t p_0 = vaddq_u16(pl_0, ph_0); + const uint16x8_t p_1 = vaddq_u16(pl_1, ph_1); + + sum11 += x0->d*y0->d*vaddvq_u16(p_0); + sum11 += x1->d*y1->d*vaddvq_u16(p_1); +#endif } sumf = QK*sum00 + sum01 + sum10 + sum11;