From 1b49d26f8a1a862e0b628c4828e53f7f5315ebf2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?H=C3=A5kon=20H=2E=20Hitland?= Date: Fri, 21 Apr 2023 00:11:49 +0200 Subject: [PATCH] q4_0c: Arm Neon acceleration Mostly copied from the q4_0 implementation --- ggml.c | 96 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 94 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 03c9cd462..78abf324c 100644 --- a/ggml.c +++ b/ggml.c @@ -1758,7 +1758,37 @@ static void quantize_row_q8_0c(const float * restrict x, void * restrict vy, int int8_t * restrict qs = vy; float * restrict ds = (float *) ((uint8_t *) vy + nb*QK8_0C); -#if __AVX512F__ +#if defined(__ARM_NEON) + for (int i = 0; i < nb; i++) { + float32x4_t srcv [8]; + float32x4_t asrcv[8]; + float32x4_t amaxv[8]; + + for (int l = 0; l < 8; l++) srcv[l] = vld1q_f32(x + i*32 + 4*l); + for (int l = 0; l < 8; l++) asrcv[l] = vabsq_f32(srcv[l]); + + for (int l = 0; l < 4; l++) amaxv[2*l] = vmaxq_f32(asrcv[2*l], asrcv[2*l+1]); + for (int l = 0; l < 2; l++) amaxv[4*l] = vmaxq_f32(amaxv[4*l], amaxv[4*l+2]); + for (int l = 0; l < 1; l++) amaxv[8*l] = vmaxq_f32(amaxv[8*l], amaxv[8*l+4]); + + const float amax = vmaxvq_f32(amaxv[0]); + + const float d = amax / ((1 << 7) - 1); + const float id = d ? 1.0f/d : 0.0f; + + ds[i] = d; + + for (int l = 0; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + qs[i*QK8_0C + 4*l + 0] = vgetq_lane_s32(vi, 0); + qs[i*QK8_0C + 4*l + 1] = vgetq_lane_s32(vi, 1); + qs[i*QK8_0C + 4*l + 2] = vgetq_lane_s32(vi, 2); + qs[i*QK8_0C + 4*l + 3] = vgetq_lane_s32(vi, 3); + } + } +#elif defined(__AVX512F__) for (int i = 0; i < nb; i++) { const __m512 x0 = _mm512_loadu_ps( x + i*QK8_0C ); const __m512 x1 = _mm512_loadu_ps( x + i*QK8_0C + QK8_0C/2); @@ -3095,7 +3125,69 @@ static void ggml_vec_dot_q4_0c_q8_0c(const int n, float * restrict s, const void float sumf = 0.0; -#if __AVX512F__ +#if defined(__ARM_NEON) + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + + for (int i = 0; i < nb/2; i++) { + const int dst0 = i + i/2*2; // 0, 1, 4, 5, 8, 9, ... + const int dst1 = i + i/2*2 + 2; // 2, 3, 6, 7, 10, 11 ... + + const uint8x16_t m4b = vdupq_n_u8(0xf); + const int8x16_t s8b = vdupq_n_s8(0x8); + + const uint8x16_t v0_01l = vld1q_u8(&xqs[i*QK4_0]); + const uint8x16_t v0_01h = vld1q_u8(&xqs[i*QK4_0 + QK4_0/2]); + + // 4-bit -> 8-bit + const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8 (v0_01l, m4b)); + const int8x16_t v0_0h = vreinterpretq_s8_u8(vandq_u8 (v0_01h, m4b)); + const int8x16_t v0_1l = vreinterpretq_s8_u8(vshrq_n_u8(v0_01l, 4)); + const int8x16_t v0_1h = vreinterpretq_s8_u8(vshrq_n_u8(v0_01h, 4)); + + // sub 8 + const int8x16_t v0_0ls = vsubq_s8(v0_0l, s8b); + const int8x16_t v0_0hs = vsubq_s8(v0_0h, s8b); + const int8x16_t v0_1ls = vsubq_s8(v0_1l, s8b); + const int8x16_t v0_1hs = vsubq_s8(v0_1h, s8b); + + // load y + const int8x16_t v1_0l = vld1q_s8(&yqs[dst0*QK8_0C]); + const int8x16_t v1_0h = vld1q_s8(&yqs[dst0*QK8_0C + 16]); + const int8x16_t v1_1l = vld1q_s8(&yqs[dst1*QK8_0C]); + const int8x16_t v1_1h = vld1q_s8(&yqs[dst1*QK8_0C + 16]); + +#if defined(__ARM_FEATURE_DOTPROD) + // dot product into int32x4_t + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0ls, v1_0l), v0_0hs, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1ls, v1_1l), v0_1hs, v1_1h); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), xds[dst0]*yds[dst0]); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), xds[dst1]*yds[dst1]); +#else + const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0ls), vget_low_s8 (v1_0l)); + const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0ls), vget_high_s8(v1_0l)); + const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hs), vget_low_s8 (v1_0h)); + const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hs), vget_high_s8(v1_0h)); + + const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1ls), vget_low_s8 (v1_1l)); + const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1ls), vget_high_s8(v1_1l)); + const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hs), vget_low_s8 (v1_1h)); + const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hs), vget_high_s8(v1_1h)); + + const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); + const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); + const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); + const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), xds[dst0]*yds[dst0]); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddq_s32(pl1, ph1)), xds[dst1]*yds[dst1]); +#endif + } + + sumf = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + +#elif defined(__AVX512F__) // Initialize accumulator with zeros __m512 acc = _mm512_setzero_ps(); for (int i = 0; i < nb; i += 4) {