diff --git a/ggml.c b/ggml.c index e937ce12b..443683873 100644 --- a/ggml.c +++ b/ggml.c @@ -656,10 +656,11 @@ static_assert(sizeof(block_q4_3) == 2 * sizeof(ggml_fp16_t) + QK4_3 / 2, "wrong #define QK8_0 32 typedef struct { float d; // delta - float s; // d * sum(qs[i]) + float s0; // d * sum(qs[i]) low + float s1; // d * sum(qs[i]) high int8_t qs[QK8_0]; // quants } block_q8_0; -static_assert(sizeof(block_q8_0) == 2*sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); +static_assert(sizeof(block_q8_0) == 3*sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); // reference implementation for deterministic creation of model files @@ -1299,13 +1300,22 @@ static void quantize_row_q8_0_reference(const float * restrict x, block_q8_0 * r y[i].d = d; - int sum = 0; - for (int l = 0; l < QK8_0; ++l) { - const float v = x[i*QK8_0 + l]*id; - y[i].qs[l] = roundf(v); - sum += y[i].qs[l]; + int sum0 = 0; + int sum1 = 0; + + for (int l = 0; l < QK8_0/2; ++l) { + const float v0 = x[i*QK8_0 + l]*id; + const float v1 = x[i*QK8_0 + QK8_0/2 + l]*id; + + y[i].qs[ l] = roundf(v0); + y[i].qs[QK8_0/2 + l] = roundf(v1); + + sum0 += y[i].qs[ l]; + sum1 += y[i].qs[QK8_0/2 + l]; } - y[i].s = d * sum; + + y[i].s0 = d * sum0; + y[i].s1 = d * sum1; } } @@ -1335,9 +1345,11 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int y[i].d = d; - int32x4_t accv = vdupq_n_s32(0); + int32x4_t accv0 = vdupq_n_s32(0); + int32x4_t accv1 = vdupq_n_s32(0); - for (int l = 0; l < 8; l++) { + // low half + for (int l = 0; l < 4; l++) { const float32x4_t v = vmulq_n_f32(srcv[l], id); const int32x4_t vi = vcvtnq_s32_f32(v); @@ -1346,12 +1358,30 @@ static void quantize_row_q8_0(const float * restrict x, void * restrict vy, int y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2); y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3); - accv = vaddq_s32(accv, vi); + accv0 = vaddq_s32(accv0, vi); } - int32_t sum = vaddvq_s32(accv); - y[i].s = d * sum; + + // high half + for (int l = 4; l < 8; l++) { + const float32x4_t v = vmulq_n_f32(srcv[l], id); + const int32x4_t vi = vcvtnq_s32_f32(v); + + y[i].qs[4*l + 0] = vgetq_lane_s32(vi, 0); + y[i].qs[4*l + 1] = vgetq_lane_s32(vi, 1); + y[i].qs[4*l + 2] = vgetq_lane_s32(vi, 2); + y[i].qs[4*l + 3] = vgetq_lane_s32(vi, 3); + + accv1 = vaddq_s32(accv1, vi); + } + + const int32_t sum0 = vaddvq_s32(accv0); + const int32_t sum1 = vaddvq_s32(accv1); + + y[i].s0 = d * sum0; + y[i].s1 = d * sum1; } #elif defined(__AVX2__) || defined(__AVX__) + // TODO !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! for (int i = 0; i < nb; i++) { // Load elements into 4 AVX vectors __m256 v0 = _mm256_loadu_ps( x ); @@ -2395,7 +2425,7 @@ static void ggml_vec_dot_q4_0_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - sum8 += x0->d * y0->s + x1->d * y1->s; + sum8 += x0->d * (y0->s0 + y0->s1) + x1->d * (y1->s0 + y1->s1); const uint8x16_t m4b = vdupq_n_u8(0xf); @@ -2562,7 +2592,7 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; - summs += x0->m * y0->s + x1->m * y1->s; + summs += x0->m * (y0->s0 + y0->s1) + x1->m * (y1->s0 + y1->s1); const uint8x16_t m4b = vdupq_n_u8(0xf); @@ -2589,8 +2619,8 @@ static void ggml_vec_dot_q4_1_q8_0(const int n, float * restrict s, const void * #if defined(__ARM_FEATURE_DOTPROD) // dot product into int32x4_t - const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0l, v1_0ls), v0_0h, v1_0hs); - const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1l, v1_1ls), v0_1h, v1_1hs); + const int32x4_t p_0 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l), v0_0hz, v1_0h); + const int32x4_t p_1 = vdotq_s32(vdotq_s32(vdupq_n_s32(0), v0_1lz, v1_1l), v0_1hz, v1_1h); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(p_0), x0->d*y0->d); sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(p_1), x1->d*y1->d); @@ -2845,6 +2875,8 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * float32x4_t sumv0 = vdupq_n_f32(0.0f); float32x4_t sumv1 = vdupq_n_f32(0.0f); + float summs = 0.0f; + for (int i = 0; i < nb; i += 2) { const block_q4_3 * restrict x0_0 = &x[2*(i + 0) + 0]; const block_q4_3 * restrict x0_1 = &x[2*(i + 0) + 1]; @@ -2854,6 +2886,9 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * const block_q8_0 * restrict y0 = &y[i + 0]; const block_q8_0 * restrict y1 = &y[i + 1]; + summs += GGML_FP16_TO_FP32(x0_0->m) * y0->s0 + GGML_FP16_TO_FP32(x0_1->m) * y0->s1; + summs += GGML_FP16_TO_FP32(x1_0->m) * y1->s0 + GGML_FP16_TO_FP32(x1_1->m) * y1->s1; + const uint8x16_t m4b = vdupq_n_u8(0xf); const float x0_0d = GGML_FP16_TO_FP32(x0_0->d); @@ -2861,11 +2896,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * const float x1_0d = GGML_FP16_TO_FP32(x1_0->d); const float x1_1d = GGML_FP16_TO_FP32(x1_1->d); - const float x0_0m = GGML_FP16_TO_FP32(x0_0->m); - const float x0_1m = GGML_FP16_TO_FP32(x0_1->m); - const float x1_0m = GGML_FP16_TO_FP32(x1_0->m); - const float x1_1m = GGML_FP16_TO_FP32(x1_1->m); - const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs)); const uint8x16_t v0_1 = vcombine_u8(vld1_u8(x1_0->qs), vld1_u8(x1_1->qs)); @@ -2887,17 +2917,6 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * const int8x16_t v1_1l = vld1q_s8(y1->qs); const int8x16_t v1_1h = vld1q_s8(y1->qs + 16); - const int16x8_t sy0_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0l)), vmovl_s8(vget_high_s8(v1_0l))); - const int16x8_t sy0_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_0h)), vmovl_s8(vget_high_s8(v1_0h))); - - const int16x8_t sy1_0 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1l)), vmovl_s8(vget_high_s8(v1_1l))); - const int16x8_t sy1_1 = vaddq_s16(vmovl_s8(vget_low_s8(v1_1h)), vmovl_s8(vget_high_s8(v1_1h))); - - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_0), vget_high_s16(sy0_0))), x0_0m*y0->d); - sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy0_1), vget_high_s16(sy0_1))), x0_1m*y0->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_0), vget_high_s16(sy1_0))), x1_0m*y1->d); - sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vaddl_s16(vget_low_s16(sy1_1), vget_high_s16(sy1_1))), x1_1m*y1->d); - #if defined(__ARM_FEATURE_DOTPROD) sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d); sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d); @@ -2926,7 +2945,7 @@ static void ggml_vec_dot_q4_3_q8_0(const int n, float * restrict s, const void * #endif } - *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + sumf = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs; #elif defined(__AVX2__) // Initialize accumulator with zeros __m256 acc = _mm256_setzero_ps();