k_quants: WIP super-blocks with 64 weights

Q2_K working on ARM_NEON, but quite a bit slower than 256 weights
This commit is contained in:
Iwan Kawrakow 2023-06-22 18:16:08 +03:00
parent cda47a6b2f
commit 80c75fe821

View file

@ -1417,81 +1417,61 @@ void ggml_vec_dot_q2_K_q8_K(const int n, float * restrict s, const void * restri
#ifdef __ARM_NEON #ifdef __ARM_NEON
const uint8x16_t m3 = vdupq_n_u8(0x3); const uint8x16_t m3 = vdupq_n_u8(0x3);
const uint8x16_t m4 = vdupq_n_u8(0xF);
const int32x4_t vzero = vdupq_n_s32(0); const int32x4_t vzero = vdupq_n_s32(0);
int8x16x2_t q2bytes; int8x16x4_t q2bytes;
uint8_t aux[16];
uint32_t aux32[2];
const uint8_t * scales = (const uint8_t *)aux32;
float sum = 0; float sum = 0;
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d); const float d = y[i].d * (float)x[i].d;
const float dmin = -y[i].d * ggml_fp16_to_fp32(x[i].dmin); const float dmin = -y[i].d * (float)x[i].dmin;
const uint8_t * restrict q2 = x[i].qs; const uint8_t * restrict q2 = x[i].qs;
const int8_t * restrict q8 = y[i].qs; const int8_t * restrict q8 = y[i].qs;
const uint8_t * restrict sc = x[i].scales; const uint32_t * restrict sc = (const uint32_t *)x[i].scales;
const uint8x16_t mins_and_scales = vld1q_u8(sc); aux32[0] = sc[0] & 0x0f0f0f0f;
const uint8x16_t scales = vandq_u8(mins_and_scales, m4); aux32[1] = (sc[0] >> 4) & 0x0f0f0f0f;
vst1q_u8(aux, scales);
const uint8x16_t mins = vshrq_n_u8(mins_and_scales, 4); sum += dmin * (scales[4] * y[i].bsums[0] + scales[5] * y[i].bsums[1] + scales[6] * y[i].bsums[2] + scales[7] * y[i].bsums[3]);
const int16x8x2_t q8sums = vld1q_s16_x2(y[i].bsums);
const int16x8x2_t mins16 = {vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(mins))), vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(mins)))};
const int32x4_t s0 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[0]), vget_low_s16 (q8sums.val[0])),
vmull_s16(vget_high_s16(mins16.val[0]), vget_high_s16(q8sums.val[0])));
const int32x4_t s1 = vaddq_s32(vmull_s16(vget_low_s16 (mins16.val[1]), vget_low_s16 (q8sums.val[1])),
vmull_s16(vget_high_s16(mins16.val[1]), vget_high_s16(q8sums.val[1])));
sum += dmin * vaddvq_s32(vaddq_s32(s0, s1));
int isum = 0; int isum1 = 0, isum2 = 0;
int is = 0;
const uint8x16_t q2bits = vld1q_u8(q2);
const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits, m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 2), m3));
q2bytes.val[2] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 4), m3));
q2bytes.val[3] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits, 6), m3));
// We use this macro instead of a function call because for some reason
// the code runs 2-3% slower, even if the function is declared inline
#if defined(__ARM_FEATURE_DOTPROD) #if defined(__ARM_FEATURE_DOTPROD)
#define MULTIPLY_ACCUM_WITH_SCALE(index)\ isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * scales[0];
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[0], q8bytes.val[0])) * aux[is+(index)];\ isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * scales[1];
isum += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[1], q8bytes.val[1])) * aux[is+1+(index)]; isum1 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[2], q8bytes.val[2])) * scales[2];
isum2 += vaddvq_s32(vdotq_s32(vzero, q2bytes.val[3], q8bytes.val[3])) * scales[3];
#else #else
#define MULTIPLY_ACCUM_WITH_SCALE(index)\ const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
{\ vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[0]), vget_low_s8 (q8bytes.val[0])),\ const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q2bytes.val[0]), vget_high_s8(q8bytes.val[0])));\ vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[1]), vget_low_s8 (q8bytes.val[1])),\ isum1 += vaddvq_s16(p1) * scales[0];
vmull_s8(vget_high_s8(q2bytes.val[1]), vget_high_s8(q8bytes.val[1])));\ isum2 += vaddvq_s16(p2) * scales[1];
isum += vaddvq_s16(p1) * aux[is+(index)] + vaddvq_s16(p2) * aux[is+1+(index)];\
} const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q2bytes.val[2]), vget_high_s8(q8bytes.val[2])));
const int16x8_t p4 = vaddq_s16(vmull_s8(vget_low_s8 (q2bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q2bytes.val[3]), vget_high_s8(q8bytes.val[3])));
isum1 += vaddvq_s16(p3) * scales[2];
isum2 += vaddvq_s16(p4) * scales[3];
#endif #endif
sum += d * (isum1 + isum2);
#define SHIFT_MULTIPLY_ACCUM_WITH_SCALE(shift, index)\
q8bytes = vld1q_s8_x2(q8); q8 += 32;\
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[0], (shift)), m3));\
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q2bits.val[1], (shift)), m3));\
MULTIPLY_ACCUM_WITH_SCALE((index));
for (int j = 0; j < QK_K/128; ++j) {
const uint8x16x2_t q2bits = vld1q_u8_x2(q2); q2 += 32;
int8x16x2_t q8bytes = vld1q_s8_x2(q8); q8 += 32;
q2bytes.val[0] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[0], m3));
q2bytes.val[1] = vreinterpretq_s8_u8(vandq_u8(q2bits.val[1], m3));
MULTIPLY_ACCUM_WITH_SCALE(0);
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(2, 2);
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(4, 4);
SHIFT_MULTIPLY_ACCUM_WITH_SCALE(6, 6);
is += 8;
}
sum += d * isum;
} }