k_quants: WIP super-blocks with 64 weights

Q4_K working on ARM_NEON, but quite a bit slower than 256 weights
This commit is contained in:
Iwan Kawrakow 2023-06-22 16:31:51 +03:00
parent 03f30c8eca
commit cda47a6b2f

View file

@ -2363,92 +2363,71 @@ void ggml_vec_dot_q4_K_q8_K(const int n, float * restrict s, const void * restri
const int nb = n / QK_K;
#ifdef z__ARM_NEON
#ifdef __ARM_NEON
const uint8x16_t m4b = vdupq_n_u8(0xf);
#ifdef __ARM_FEATURE_DOTPROD
const int32x4_t mzero = vdupq_n_s32(0);
#endif
int8x16x2_t q4bytes;
int8x16x2_t q8bytes;
float sumf = 0;
int8x16x2_t q4bytes;
int8x16x4_t q8bytes;
float sum_mins = 0.f;
for (int i = 0; i < nb; ++i) {
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
const float dmin = y[i].d * ggml_fp16_to_fp32(x[i].dmin);
const int16x8_t q8sums = vpaddq_s16(vld1q_s16(y[i].bsums), vld1q_s16(y[i].bsums + 8));
memcpy(utmp, x[i].scales, 12);
const uint32x2_t mins8 = {utmp[1] & kmask1, ((utmp[2] >> 4) & kmask2) | (((utmp[1] >> 6) & kmask3) << 4)};
utmp[1] = (utmp[2] & kmask2) | (((utmp[0] >> 6) & kmask3) << 4);
utmp[0] &= kmask1;
const int16x8_t mins = vreinterpretq_s16_u16(vmovl_u8(vreinterpret_u8_u32(mins8)));
const int32x4_t prod = vaddq_s32(vmull_s16(vget_low_s16 (q8sums), vget_low_s16 (mins)),
vmull_s16(vget_high_s16(q8sums), vget_high_s16(mins)));
sumf -= dmin * vaddvq_s32(prod);
const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
//int32x4_t isum = mzero;
const float32x4_t dsc = vcvt_f32_f16(vld1_f16(x[i].d));
float summ = vgetq_lane_f32(dsc, 1) * (y[i].bsums[0] + y[i].bsums[1])
+ vgetq_lane_f32(dsc, 3) * (y[i].bsums[2] + y[i].bsums[3]);
sum_mins += y[i].d * summ;
int32_t sumi1 = 0;
int32_t sumi2 = 0;
for (int j = 0; j < QK_K/64; ++j) {
const uint8x16x2_t q4bits = vld1q_u8_x2(q4); q4 += 32;
const uint8x16x2_t q4bits = vld1q_u8_x2(q4);
#ifdef __ARM_FEATURE_DOTPROD
q8bytes = vld1q_s8_x2(q8); q8 += 32;
q8bytes = vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int32x4_t p1 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
sumi1 += vaddvq_s32(p1) * scales[2*j+0];
const float sumf1 = vaddvq_s32(p1) * vgetq_lane_f32(dsc, 0);
q8bytes = vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[0]), q4bytes.val[1], q8bytes.val[1]);
const int32x4_t p2 = vdotq_s32(vdotq_s32(mzero, q4bytes.val[0], q8bytes.val[2]), q4bytes.val[1], q8bytes.val[3]);
const float sumf2 = vaddvq_s32(p2) * vgetq_lane_f32(dsc, 2);
sumi2 += vaddvq_s32(p2) * scales[2*j+1];
#else
q8bytes = vld1q_s8_x2(q8); q8 += 32;
q8bytes = vld1q_s8_x4(q8);
q4bytes.val[0] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[0], m4b));
q4bytes.val[1] = vreinterpretq_s8_u8(vandq_u8 (q4bits.val[1], m4b));
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
sumi1 += vaddvq_s16(vaddq_s16(p0, p1)) * scales[2*j+0];
float sumf1 = vaddvq_s16(vaddq_s16(p0, p1)) * vgetq_lane_f32(dsc, 0);
q8bytes = vld1q_s8_x2(q8); q8 += 32;
q4bytes.val[0] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[0], 4));
q4bytes.val[1] = vreinterpretq_s8_u8(vshrq_n_u8(q4bits.val[1], 4));
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[0])));
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[1])));
sumi2 += vaddvq_s16(vaddq_s16(p2, p3)) * scales[2*j+1];
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[0]), vget_low_s8 (q8bytes.val[2])),
vmull_s8(vget_high_s8(q4bytes.val[0]), vget_high_s8(q8bytes.val[2])));
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q4bytes.val[1]), vget_low_s8 (q8bytes.val[3])),
vmull_s8(vget_high_s8(q4bytes.val[1]), vget_high_s8(q8bytes.val[3])));
float sumf2 = vaddvq_s16(vaddq_s16(p2, p3)) * vgetq_lane_f32(dsc, 2);
#endif
}
sumf += d * (sumi1 + sumi2);
sumf += y[i].d * (sumf1 + sumf2);
}
*s = sumf;
*s = sumf - sum_mins;
#elif defined __AVX2__