A slightly faster ARM_NEON Q2_K dot
Single-token prediction is now ~36 ms on M2 Max. The code is much simpler too.
commit 7bcc37676a
parent 6ec70579cb

1 changed file with 33 additions and 44 deletions
k_quants.c
@@ -40,7 +40,7 @@
 #define MAX(a, b) ((a) > (b) ? (a) : (b))
 
 //
-// 3-6 bit quantization in super-blocks
+// 2-6 bit quantization in super-blocks
 //
 
@@ -1015,7 +1015,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
     const uint8x16_t m4 = vdupq_n_u8(0xF);
     const int32x4_t vzero = vdupq_n_s32(0);
 
-    int8x16x4_t q2bytes;
+    int8x16x2_t q2bytes;
     uint8_t aux[16];
 
     float sum = 0;
@@ -1032,6 +1032,7 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         const uint8x16_t mins_and_scales = vld1q_u8(sc);
         const uint8x16_t scales = vandq_u8(mins_and_scales, m4);
         vst1q_u8(aux, scales);
+
         const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4);
         const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
         const int16x8x2_t mins16 = {vmovl_u8(vget_low_u8(mins)), vmovl_u8(vget_high_u8(mins))};
@@ -1044,58 +1045,46 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
         int isum = 0;
         int is = 0;
 
+// We use this macro instead of a function call because for some reason
+// the code runs 2-3% slower, even if the function is declared inline
+#if defined(__ARM_FEATURE_DOTPROD)
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\
+        isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)];
+#else
+#define MULTIPLY_ACCUM_WITH_SCALE(index)\
+        {\
+    const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\
+                                   vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\
+    const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\
+                                   vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\
+    isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
+        }
+#endif
+
+#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
+        q8bytes = vld1q_s8_x2(q8); q8 += 32;\
+        q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
+        q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
+        MULTIPLY_ACCUM_WITH_SCALE((index));
+
+
         for (int j = 0; j < QK_K/128; ++j) {
 
             const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
-            const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
-            const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
 
+            int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
             q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
             q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
-            q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], 2), m3));
-            q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], 2), m3));
+            MULTIPLY_ACCUM_WITH_SCALE(0);
 
-#if defined(__ARM_FEATURE_DOTPROD)
-            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes_1.val[0])) * aux[is+0];
-            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes_1.val[1])) * aux[is+1];
-            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes_1.val[2])) * aux[is+2];
-            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes_1.val[3])) * aux[is+3];
-#else
-            int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
-                                     vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
-            int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
-                                     vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
-            int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
-                                     vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
-            int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
-                                     vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
-            isum += vaddvq_s16(p0) * aux[is+0] + vaddvq_s16(p1) * aux[is+1] + vaddvq_s16(p2) * aux[is+2] + vaddvq_s16(p3) * aux[is+3];
-#endif
-            is += 4;
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
 
-            q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], 4), m3));
-            q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], 4), m3));
-            q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], 6), m3));
-            q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], 6), m3));
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
 
-#if defined(__ARM_FEATURE_DOTPROD)
-            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes_2.val[0])) * aux[is+0];
-            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes_2.val[1])) * aux[is+1];
-            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes_2.val[2])) * aux[is+2];
-            isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes_2.val[3])) * aux[is+3];
-#else
-            p0 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
-                           vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
-            p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
-                           vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
-            p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
-                           vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
-            p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
-                           vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
-            isum += vaddvq_s16(p0) * aux[is+0] + vaddvq_s16(p1) * aux[is+1] + vaddvq_s16(p2) * aux[is+2] + vaddvq_s16(p3) * aux[is+3];
-#endif
-            is += 4;
+            SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
 
+            is += 8;
         }
         sum += d * isum;
 
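Read together with the macros, the rewritten loop computes sixteen 16-value sub-block dot products per 256-value super-block, each weighted by one of the 4-bit scales unpacked into aux[] in the hunk at line 1032. Below is a scalar sketch of the isum accumulation for one super-block, for illustration only: the helper name q2k_dot_superblock_ref and its flat argument list are invented here, and the sketch omits the outer loop over blocks, the final multiplication by the block scale d, and the mins correction derived from the high nibbles and y[i].bsums.

#include <stdint.h>

// Hypothetical scalar reference for one Q2_K super-block (QK_K = 256 values).
//   q2  : 64 bytes, four 2-bit quants packed per byte
//   q8  : 256 int8 values of the matching Q8_K block
//   aux : the 16 sub-block scales (low nibbles of mins_and_scales above)
// Returns the integer sum the NEON code accumulates into `isum`.
static int q2k_dot_superblock_ref(const uint8_t * q2, const int8_t * q8,
                                  const uint8_t * aux) {
    int isum = 0;
    int is = 0;                                // scale index, advances by 8 per 128 values
    for (int j = 0; j < 256/128; ++j) {        // two 128-value halves per super-block
        // Shifting by 0/2/4/6 and masking with 3 selects one 2-bit "plane"
        // per pass -- exactly what SHIFT_MULTIPLY_ACCUM_WITH_SCALE does with
        // vshrq_n_u8 and vandq_u8(..., m3).
        for (int shift = 0; shift < 8; shift += 2) {
            for (int k = 0; k < 2; ++k) {      // two 16-byte vectors per pass
                int sumi = 0;                  // one vdotq_s32/vaddvq_s32 reduction
                for (int l = 0; l < 16; ++l) {
                    sumi += ((q2[16*k + l] >> shift) & 3) * (*q8++);
                }
                isum += sumi * aux[is++];      // weight by the sub-block scale
            }
        }
        q2 += 32;                              // next 32 bytes hold the next 128 quants
    }
    return isum;
}

With __ARM_FEATURE_DOTPROD the inner 16-element sum becomes a single vdotq_s32 followed by vaddvq_s32; without it, the diff falls back to widening vmull_s8 multiplies reduced with vaddvq_s16.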