From d8bf7207f1ccd713aa43a7dff29012e836fe8c6d Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 25 Apr 2023 22:03:08 +0300 Subject: [PATCH] ggml : finalize Q8_0 implementation --- ggml.c | 50 +++++++++++++++++++------------------------------- 1 file changed, 19 insertions(+), 31 deletions(-) diff --git a/ggml.c b/ggml.c index 5205af154..b451dd901 100644 --- a/ggml.c +++ b/ggml.c @@ -1829,7 +1829,7 @@ static void ggml_vec_dot_q4_0_q8_1(const int n, float * restrict s, const void * static void ggml_vec_dot_q4_1_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_2_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); -static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); +static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { [GGML_TYPE_Q4_0] = { @@ -1864,8 +1864,8 @@ static const quantize_fns_t quantize_fns[GGML_TYPE_COUNT] = { .dequantize_row_q = dequantize_row_q8_0, .quantize_row_q = quantize_row_q8_0, .quantize_row_q_reference = (quantize_row_q_t) quantize_row_q8_0_reference, - .quantize_row_q_dot = quantize_row_q8_1, - .vec_dot_q = ggml_vec_dot_q8_0_q8_1, + .quantize_row_q_dot = quantize_row_q8_0, + .vec_dot_q = ggml_vec_dot_q8_0_q8_0, }, [GGML_TYPE_Q8_1] = { .dequantize_row_q = NULL, // TODO @@ -3062,23 +3062,23 @@ static void ggml_vec_dot_q4_3_q8_1(const int n, float * restrict s, const void * #endif } -static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { - const int nb = n / QK8_1; +static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const int nb = n / QK8_0; - assert(n % QK8_1 == 0); + assert(n % QK8_0 == 0); assert(nb % 2 == 0); - assert(QK8_1 == QK8_0); + assert(QK8_0 == QK8_0); const block_q8_0 * restrict x = vx; - const block_q8_1 * restrict y = vy; + const block_q8_0 * restrict y = vy; -#if defined(__ARM_NEON_XXX) +#if defined(__ARM_NEON) float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); for (int i = 0; i < nb; ++i) { const block_q8_0 * restrict x0 = &x[i]; - const block_q8_1 * restrict y0 = &y[i]; + const block_q8_0 * restrict y0 = &y[i]; const int8x16_t v0_0 = vld1q_s8(x0->qs); const int8x16_t v0_1 = vld1q_s8(x0->qs + 16); @@ -3096,28 +3096,16 @@ static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void * vdotq_s32(vdupq_n_s32(0), v0_0, v1_1), vdotq_s32(vdupq_n_s32(0), v0_1, v1_0))), x0->d*y0->d); #else - const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l)); - const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l)); - const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h)); - const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h)); + const int16x8_t p0l = vmull_s8(vget_low_s8 (v0_0), vget_low_s8 (v1_0)); + const int16x8_t p0h = vmull_s8(vget_high_s8(v0_0), vget_high_s8(v1_0)); + const int16x8_t p1l = vmull_s8(vget_low_s8 (v0_1), vget_low_s8 (v1_1)); + const int16x8_t p1h = vmull_s8(vget_high_s8(v0_1), vget_high_s8(v1_1)); - const int16x8_t pl1l = vmull_s8(vget_low_s8 (v0_1lz), vget_low_s8 (v1_1l)); - const int16x8_t pl1h = vmull_s8(vget_high_s8(v0_1lz), vget_high_s8(v1_1l)); - const int16x8_t ph1l = vmull_s8(vget_low_s8 (v0_1hz), vget_low_s8 (v1_1h)); - const int16x8_t ph1h = vmull_s8(vget_high_s8(v0_1hz), vget_high_s8(v1_1h)); + const int32x4_t pl = vaddq_s32(vpaddlq_s16(p0l), vpaddlq_s16(p0h)); + const int32x4_t ph = vaddq_s32(vpaddlq_s16(p1l), vpaddlq_s16(p1h)); - const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h)); - const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h)); - const int32x4_t pl1 = vaddq_s32(vpaddlq_s16(pl1l), vpaddlq_s16(pl1h)); - const int32x4_t ph1 = vaddq_s32(vpaddlq_s16(ph1l), vpaddlq_s16(ph1h)); - - sumv0 = vmlaq_n_f32(sumv0, vaddq_f32( - vmulq_n_f32(vcvtq_f32_s32(pl0), GGML_FP16_TO_FP32(x0_0->d)), - vmulq_n_f32(vcvtq_f32_s32(ph0), GGML_FP16_TO_FP32(x0_1->d))), y0->d); - - sumv1 = vmlaq_n_f32(sumv1, vaddq_f32( - vmulq_n_f32(vcvtq_f32_s32(pl1), GGML_FP16_TO_FP32(x1_0->d)), - vmulq_n_f32(vcvtq_f32_s32(ph1), GGML_FP16_TO_FP32(x1_1->d))), y1->d); + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl), x0->d*y0->d); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph), x0->d*y0->d); #endif } @@ -3132,7 +3120,7 @@ static void ggml_vec_dot_q8_0_q8_1(const int n, float * restrict s, const void * int sumi = 0; - for (int j = 0; j < QK8_1; j++) { + for (int j = 0; j < QK8_0; j++) { const int v0 = x0[j]; const int v1 = y0[j];