From ef8e3ee6f5efa8067486a0c2b5ffa35d5900a6b7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Wed, 26 Apr 2023 13:58:47 +0300
Subject: [PATCH] ggml : q5_0 scalar dot product

Replace the GGML_ASSERT(false) placeholder in ggml_vec_dot_q5_0_q8_1
with a working scalar path. One q5_0 block is now paired with one q8_1
block (the assert becomes QK8_1 == QK5_0). Each quant is rebuilt from a
4-bit nibble of qs plus one bit of qh that becomes bit 4 of the value:
bit 2*j for the low nibble of qs[j], bit 2*j + 1 for the high nibble.

The NEON branch is now guarded by __ARM_NEON_XXX (presumably to keep it
disabled until it is ported to the new layout - confirm).
---
Review notes:
 - The qh masks are built with 1u instead of 1. With a plain int,
   1 << 31 is undefined behaviour in C; the shift count 2*j + 1 reaches
   31 if QK8_1 is 32 (confirm against the QK8_1 definition).
   The computed values are unchanged.
 - The post-image blob id on the index line predates this edit.

 ggml.c | 42 +++++++++++++++++-------------------------
 1 file changed, 17 insertions(+), 25 deletions(-)

diff --git a/ggml.c b/ggml.c
index 423b95952..f4bc9db57 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3167,18 +3167,16 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void *
 }
 
 static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
-    GGML_ASSERT(false); // TODO xxxxxxxxx
-
     const int nb = n / QK8_1;
 
     assert(n % QK8_1 == 0);
     assert(nb % 2 == 0);
-    assert(QK8_1 == 2*QK5_0);
+    assert(QK8_1 == QK5_0);
 
     const block_q5_0 * restrict x = vx;
     const block_q8_1 * restrict y = vy;
 
-#if defined(__ARM_NEON)
+#if defined(__ARM_NEON_XXX)
     float32x4_t sumv0 = vdupq_n_f32(0.0f);
     float32x4_t sumv1 = vdupq_n_f32(0.0f);
 
@@ -3257,43 +3255,37 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void *
 
     *s = hsum_float_8(acc) + summs;
 #else
-    // scalar
     float sumf = 0.0;
+
     for (int i = 0; i < nb; i++) {
-        const uint8_t * restrict x0 = x[2*i + 0].qs;
-        const uint8_t * restrict x1 = x[2*i + 1].qs;
+        const uint8_t * restrict x0 = x[i].qs;
         const int8_t * restrict y0 = y[i].qs;
 
-        const float d0 = GGML_FP16_TO_FP32(x[2*i + 0].d);
-        const float m0 = GGML_FP16_TO_FP32(x[2*i + 0].m);
-        const float d1 = GGML_FP16_TO_FP32(x[2*i + 1].d);
-        const float m1 = GGML_FP16_TO_FP32(x[2*i + 1].m);
+        const uint32_t qh = x[i].qh; // bit k becomes bit 4 of the k-th quant of this block
 
-        int sxy_0 = 0;
-        int sxy_1 = 0;
+        const float d = GGML_FP16_TO_FP32(x[i].d);
+        const float m = GGML_FP16_TO_FP32(x[i].m);
 
-        for (int j = 0; j < QK8_1/4; j++) {
+        int sxy = 0;
+
+        for (int j = 0; j < QK8_1/2; j++) {
             const uint8_t v0 = x0[j];
-            const uint8_t v1 = x1[j];
 
-            const int x0_0 = v0 & 0x0F;
-            const int x1_0 = v0 >> 4;
+            const int x0_0h = ((qh & (1u << (2*j + 0))) >> (2*j + 0)) << 4;
+            const int x1_0h = ((qh & (1u << (2*j + 1))) >> (2*j + 1)) << 4;
 
-            const int x0_1 = v1 & 0x0F;
-            const int x1_1 = v1 >> 4;
+            const int x0_0 = (v0 & 0x0F) | x0_0h;
+            const int x1_0 = (v0 >> 4) | x1_0h;
 
             const int y0_0 = y0[2*j + 0];
             const int y1_0 = y0[2*j + 1];
 
-            const int y0_1 = y0[2*(j + QK8_1/4) + 0];
-            const int y1_1 = y0[2*(j + QK8_1/4) + 1];
-
-            sxy_0 += x0_0*y0_0 + x1_0*y1_0;
-            sxy_1 += x0_1*y0_1 + x1_1*y1_1;
+            sxy += x0_0*y0_0 + x1_0*y1_0;
         }
 
-        sumf += (d0*sxy_0 + d1*sxy_1)*y[i].d + m0*y[i].s0 + m1*y[i].s1;
+        sumf += (d*sxy)*y[i].d + m*(y[i].s0 + y[i].s1);
     }
+
     *s = sumf;
 #endif
 }