diff --git a/ggml.c b/ggml.c
index f4bc9db57..6abd1cf90 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3176,57 +3176,79 @@ static void ggml_vec_dot_q5_0_q8_1(const int n, float * restrict s, const void *
     const block_q5_0 * restrict x = vx;
     const block_q8_1 * restrict y = vy;
 
-#if defined(__ARM_NEON_XXX)
-    float32x4_t sumv0 = vdupq_n_f32(0.0f);
-    float32x4_t sumv1 = vdupq_n_f32(0.0f);
+#if defined(__ARM_NEON)
+    float32x4_t sumv = vdupq_n_f32(0.0f);
 
-    float summs0 = 0.0f;
-    float summs1 = 0.0f;
+    float summs = 0.0f;
+
+    uint32_t tmp[8];
+
+    static const uint32_t k_mask[16] = {
+        0x00000000, 0x00000010, 0x00001000, 0x00001010,
+        0x00100000, 0x00100010, 0x00101000, 0x00101010,
+        0x10000000, 0x10000010, 0x10001000, 0x10001010,
+        0x10100000, 0x10100010, 0x10101000, 0x10101010,
+    };
 
     for (int i = 0; i < nb; ++i) {
-        const block_q5_0 * restrict x0_0 = &x[2*(i + 0) + 0];
-        const block_q5_0 * restrict x0_1 = &x[2*(i + 0) + 1];
+        const block_q5_0 * restrict x0 = &x[i];
+        const block_q8_1 * restrict y0 = &y[i];
 
-        const block_q8_1 * restrict y0 = &y[i + 0];
+        summs += GGML_FP16_TO_FP32(x0->m) * (y0->s0 + y0->s1);
 
-        summs0 += GGML_FP16_TO_FP32(x0_0->m) * y0->s0;
-        summs1 += GGML_FP16_TO_FP32(x0_1->m) * y0->s1;
+        // extract the 5th bit
+        const uint32_t qh = x0->qh;
 
-        const uint8x16_t v0_0 = vcombine_u8(vld1_u8(x0_0->qs), vld1_u8(x0_1->qs));
+        tmp[0] = k_mask[(qh >>  0) & 0x0F];
+        tmp[1] = k_mask[(qh >>  4) & 0x0F];
+        tmp[2] = k_mask[(qh >>  8) & 0x0F];
+        tmp[3] = k_mask[(qh >> 12) & 0x0F];
+        tmp[4] = k_mask[(qh >> 16) & 0x0F];
+        tmp[5] = k_mask[(qh >> 20) & 0x0F];
+        tmp[6] = k_mask[(qh >> 24) & 0x0F];
+        tmp[7] = k_mask[(qh >> 28)];
+
+        const int8x16_t qhl = vld1q_s8((const int8_t *)(tmp + 0));
+        const int8x16_t qhh = vld1q_s8((const int8_t *)(tmp + 4));
+
+        const uint8x16_t v0 = vld1q_u8(x0->qs);
 
         // 4-bit -> 8-bit
-        const int8x16_t v0_0l = vreinterpretq_s8_u8(vandq_u8  (v0_0, vdupq_n_u8(0x0F)));
-        const int8x16_t v0_0h = vreinterpretq_s8_u8(vshrq_n_u8(v0_0, 4));
+        const int8x16_t v0l = vreinterpretq_s8_u8(vandq_u8  (v0, vdupq_n_u8(0x0F)));
+        const int8x16_t v0h = vreinterpretq_s8_u8(vshrq_n_u8(v0, 4));
 
         // interleave
-        const int8x16_t v0_0lz = vzip1q_s8(v0_0l, v0_0h);
-        const int8x16_t v0_0hz = vzip2q_s8(v0_0l, v0_0h);
+        const int8x16_t v0lz = vzip1q_s8(v0l, v0h);
+        const int8x16_t v0hz = vzip2q_s8(v0l, v0h);
+
+        // add
+        const int8x16_t v0lf = vorrq_s8(v0lz, qhl);
+        const int8x16_t v0hf = vorrq_s8(v0hz, qhh);
 
         // load y
-        const int8x16_t v1_0l = vld1q_s8(y0->qs);
-        const int8x16_t v1_0h = vld1q_s8(y0->qs + 16);
+        const int8x16_t v1l = vld1q_s8(y0->qs);
+        const int8x16_t v1h = vld1q_s8(y0->qs + 16);
 
-        const float x0_0d = GGML_FP16_TO_FP32(x0_0->d);
-        const float x0_1d = GGML_FP16_TO_FP32(x0_1->d);
+        const float x0d = GGML_FP16_TO_FP32(x0->d);
 
 #if defined(__ARM_FEATURE_DOTPROD)
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0lz, v1_0l)), x0_0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(vdotq_s32(vdupq_n_s32(0), v0_0hz, v1_0h)), x0_1d*y0->d);
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(
+                        vdotq_s32(vdupq_n_s32(0), v0lf, v1l),
+                        vdotq_s32(vdupq_n_s32(0), v0hf, v1h))), x0d*y0->d);
 #else
-        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0_0lz), vget_low_s8 (v1_0l));
-        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0_0lz), vget_high_s8(v1_0l));
-        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0_0hz), vget_low_s8 (v1_0h));
-        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0_0hz), vget_high_s8(v1_0h));
+        const int16x8_t pl0l = vmull_s8(vget_low_s8 (v0lf), vget_low_s8 (v1l));
+        const int16x8_t pl0h = vmull_s8(vget_high_s8(v0lf), vget_high_s8(v1l));
+        const int16x8_t ph0l = vmull_s8(vget_low_s8 (v0hf), vget_low_s8 (v1h));
+        const int16x8_t ph0h = vmull_s8(vget_high_s8(v0hf), vget_high_s8(v1h));
 
         const int32x4_t pl0 = vaddq_s32(vpaddlq_s16(pl0l), vpaddlq_s16(pl0h));
         const int32x4_t ph0 = vaddq_s32(vpaddlq_s16(ph0l), vpaddlq_s16(ph0h));
 
-        sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(pl0), x0_0d*y0->d);
-        sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(ph0), x0_1d*y0->d);
+        sumv = vmlaq_n_f32(sumv, vcvtq_f32_s32(vaddq_s32(pl0, ph0)), x0d*y0->d);
 #endif
     }
 
-    *s = vaddvq_f32(vaddq_f32(sumv0, sumv1)) + summs0 + summs1;
+    *s = vaddvq_f32(sumv) + summs;
 #elif defined(__AVX2__)
     // Initialize accumulator with zeros
     __m256 acc = _mm256_setzero_ps();
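For reference, below is a minimal scalar sketch of what the k_mask lookup in the patch computes. It is not part of the patch: the helper name expand_qh and the test pattern are illustrative only, and it assumes a little-endian target, as the NEON path does. Each 4-bit nibble of qh selects a 32-bit word whose four bytes are 0x00 or 0x10 according to the nibble's bits, so eight lookups expand the 32 high bits of a q5_0 block into 32 byte-lane flags that vorrq_s8 can then merge into the low 4-bit values.

// Scalar sketch (not part of the patch), assuming a little-endian target:
// expand the 32 high bits of a q5_0 block into 32 byte flags (0x00 or 0x10)
// via the same k_mask nibble lookup, then check the result against direct
// bit extraction. expand_qh and the test pattern are illustrative only.
#include <assert.h>
#include <stdint.h>
#include <string.h>

static const uint32_t k_mask[16] = {
    0x00000000, 0x00000010, 0x00001000, 0x00001010,
    0x00100000, 0x00100010, 0x00101000, 0x00101010,
    0x10000000, 0x10000010, 0x10001000, 0x10001010,
    0x10100000, 0x10100010, 0x10101000, 0x10101010,
};

// Each nibble of qh indexes a 32-bit word whose four bytes encode the
// nibble's four bits as 0x00/0x10, so eight lookups cover all 32 bits.
static void expand_qh(uint32_t qh, uint8_t out[32]) {
    uint32_t tmp[8];
    for (int j = 0; j < 8; ++j) {
        tmp[j] = k_mask[(qh >> (4*j)) & 0x0F];
    }
    memcpy(out, tmp, 32); // little-endian: byte k of tmp[j] is bit 4*j + k
}

int main(void) {
    const uint32_t qh = 0xA5C3F00Fu; // arbitrary test pattern
    uint8_t out[32];
    expand_qh(qh, out);
    for (int k = 0; k < 32; ++k) {
        assert(out[k] == (uint8_t)(((qh >> k) & 1) << 4));
    }
    return 0;
}

The staging through tmp[] mirrors the patch: the scalar table lookups land in contiguous memory so that, on the NEON side, two vld1q_s8 loads per block replace per-bit shift-and-mask work.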