From 7bcc37676ad08c8574b1a8afc4d6c7ac56a86c5d Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 1 Jun 2023 11:27:14 +0300 Subject: [PATCH] A slightly faster ARM_NEON Q2_K dot Single token prediction is now ~36 ms on M2 Max. The code is much simpler too. --- k_quants.c | 77 +++++++++++++++++++++++------------------------------- 1 file changed, 33 insertions(+), 44 deletions(-) diff --git a/k_quants.c b/k_quants.c index 900552a7a..79b021cf0 100644 --- a/k_quants.c +++ b/k_quants.c @@ -40,7 +40,7 @@ #define MAX(a, b) ((a) > (b) ? (a) : (b)) // -// 3-6 bit quantization in super-blocks +// 2-6 bit quantization in super-blocks // @@ -1015,7 +1015,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t m4 = vdupq_n_u8(0xF); const int32x4_t vzero = vdupq_n_s32(0); - int8x16x4_t q2bytes; + int8x16x2_t q2bytes; uint8_t aux[16]; float sum = 0; @@ -1032,6 +1032,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri const uint8x16_t mins_and_scales = vld1q_u8(sc); const uint8x16_t scales = vandq_u8(mins_and_scales, m4); vst1q_u8(aux, scales); + const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums); const int16x8x2_t mins16 = {vmovl_u8(vget_low_u8(mins)), vmovl_u8(vget_high_u8(mins))}; @@ -1044,58 +1045,46 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri int isum = 0; int is = 0; +// We use this macro instead of a function call because for some reason +// the code runs 2-3% slower, even if the function is declared inline +#if defined(__ARM_FEATURE_DOTPROD) +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ + isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; +#else +#define MULTIPLY_ACCUM_WITH_SCALE(index)\ + {\ + const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ + vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ + const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ + vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ + isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\ + } +#endif + +#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\ + q8bytes = vld1q_s8_x2(q8); q8 += 32;\ + q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\ + q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\ + MULTIPLY_ACCUM_WITH_SCALE((index)); + + for (int j = 0; j < QK_K/128; ++j) { const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32; - const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64; - const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64; + int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32; q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3)); q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3)); - q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], 2), m3)); - q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], 2), m3)); + MULTIPLY_ACCUM_WITH_SCALE(0); -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes_1.val[0])) * aux[is+0]; - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes_1.val[1])) * aux[is+1]; - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes_1.val[2])) * aux[is+2]; - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes_1.val[3])) * aux[is+3]; -#else - int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])), - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes_1.val[0]))); - int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])), - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes_1.val[1]))); - int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])), - vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes_1.val[2]))); - int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])), - vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes_1.val[3]))); - isum += vaddvq_s16(p0) * aux[is+0] + vaddvq_s16(p1) * aux[is+1] + vaddvq_s16(p2) * aux[is+2] + vaddvq_s16(p3) * aux[is+3]; -#endif - is += 4; + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2); - q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], 4), m3)); - q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], 4), m3)); - q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], 6), m3)); - q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], 6), m3)); + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4); -#if defined(__ARM_FEATURE_DOTPROD) - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes_2.val[0])) * aux[is+0]; - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes_2.val[1])) * aux[is+1]; - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes_2.val[2])) * aux[is+2]; - isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes_2.val[3])) * aux[is+3]; -#else - p0 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])), - vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes_2.val[0]))); - p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])), - vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes_2.val[1]))); - p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])), - vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes_2.val[2]))); - p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])), - vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes_2.val[3]))); - isum += vaddvq_s16(p0) * aux[is+0] + vaddvq_s16(p1) * aux[is+1] + vaddvq_s16(p2) * aux[is+2] + vaddvq_s16(p3) * aux[is+3]; -#endif - is += 4; + SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6); + is += 8; } sum += d * isum;