From 13264fa067e200fe891977d48862ef610ad24daa Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 30 May 2023 14:18:47 +0300 Subject: [PATCH] Adding Q3_K dot for ARM_NEON It is 22% slower than Q4_K, despite the smaller model size. On x86_64, where we are memory bound, the Q3_K model is quite a bit faster than Q4_K. --- k_quants.c | 120 ++++++++++++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 2 deletions(-) diff --git a/k_quants.c b/k_quants.c index 07fca9d5b..74e95f4e4 100644 --- a/k_quants.c +++ b/k_quants.c @@ -892,8 +892,124 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri const int nb = n / QK_K; -#ifdef z__ARM_NEON - // TODO +#ifdef __ARM_NEON + + uint32_t aux[3]; + uint32_t utmp[4]; + + const uint8x16_t m3b = vdupq_n_u8(0x3); + const int32x4_t vzero = vdupq_n_s32(0); + + const uint8x16_t m0 = vdupq_n_u8(1); + const uint8x16_t m1 = vshlq_n_u8(m0, 1); + const uint8x16_t m2 = vshlq_n_u8(m0, 2); + const uint8x16_t m3 = vshlq_n_u8(m0, 3); + + int8x16x4_t q3bytes; + + float sum = 0; + + for (int i = 0; i < nb; ++i) { + + const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); + + const uint8_t * restrict q3 = x[i].qs; + const uint8_t * restrict qh = x[i].hmask; + const int8_t * restrict q8 = y[i].qs; + + uint8x16x2_t qhbits = vld1q_u8_x2(qh); + + uint8x16x4_t q3h; + + int32_t isum = 0; + + // Set up scales + memcpy(aux, x[i].scales, 12); + utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4); + utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4); + utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4); + utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4); + + const int8_t * scale = (const int8_t *)utmp; + + for (int j = 0; j < QK_K/128; ++j) { + + const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32; + const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; + const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; + + q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2); + q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2); + q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1); + q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), q3h.val[0]); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), q3h.val[1]); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), q3h.val[2]); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), q3h.val[3]); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * (scale[0] - 32); + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * (scale[1] - 32); + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * (scale[2] - 32); + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * (scale[3] - 32); +#else + int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); + int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); + int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); + int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); + isum += vaddvq_s16(p0) * (scale[0] - 32) + + vaddvq_s16(p1) * (scale[1] - 32) + + vaddvq_s16(p2) * (scale[2] - 32) + + vaddvq_s16(p3) * (scale[3] - 32); +#endif + scale += 4; + + q3h.val[0] = vbicq_u8(m2, qhbits.val[0]); + q3h.val[1] = vbicq_u8(m2, qhbits.val[1]); + q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1); + q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1); + + qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4); + qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4); + + q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), q3h.val[0]); + q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), q3h.val[1]); + q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), q3h.val[2]); + q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), q3h.val[3]); + +#if defined(__ARM_FEATURE_DOTPROD) + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * (scale[0] - 32); + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * (scale[1] - 32); + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * (scale[2] - 32); + isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * (scale[3] - 32); +#else + p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), + vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); + p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), + vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); + p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), + vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); + p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), + vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); + isum += vaddvq_s16(p0) * (scale[0] - 32) + + vaddvq_s16(p1) * (scale[1] - 32) + + vaddvq_s16(p2) * (scale[2] - 32) + + vaddvq_s16(p3) * (scale[3] - 32); +#endif + scale += 4; + + } + sum += d * isum; + + } + + *s = sum; + #elif defined __AVX2__ const __m256i m3 = _mm256_set1_epi8(3);