k_quants: WIP super-blocks with 64 weights
Yet another speedup for Q5_K on ARM_NEON. We are now within 10% of the QK_K = 256 version.
This commit is contained in:
parent
d92c5a9e29
commit
fae24afd01
1 changed files with 22 additions and 7 deletions
29
k_quants.c
29
k_quants.c
|
@ -2693,9 +2693,19 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
float sumf = 0;
|
float sumf = 0;
|
||||||
|
|
||||||
|
float32x4_t acc1 = vdupq_n_f32(0.f);
|
||||||
|
float32x4_t acc2 = vdupq_n_f32(0.f);
|
||||||
|
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
float sumb = -((float)x[i].d[1] * (y[i].bsums[0] + y[i].bsums[1]) + (float)x[i].d[3] * (y[i].bsums[2] + y[i].bsums[3]));
|
const float16x4_t s16 = vld1_f16(x[i].d);
|
||||||
|
const float32x4_t s32 = vmulq_n_f32(vcvt_f32_f16(s16), y[i].d);
|
||||||
|
//const int16x4_t bi16 = vld1_s16(y[i].bsums);
|
||||||
|
//const int32x4_t bi32 = vmovl_s16(vpadd_s16(bi16, bi16));
|
||||||
|
//const float32x4_t bf32 = vcvtq_f32_s32(bi32);
|
||||||
|
//sumf -= (vgetq_lane_f32(s32, 1) * vgetq_lane_f32(bf32, 0) + vgetq_lane_f32(s32, 3) * vgetq_lane_f32(bf32, 1));
|
||||||
|
// The above is slightly slower than just this:
|
||||||
|
sumf -= (vgetq_lane_f32(s32, 1) * (y[i].bsums[0] + y[i].bsums[1]) + vgetq_lane_f32(s32, 3) * (y[i].bsums[2] + y[i].bsums[3]));
|
||||||
|
|
||||||
const uint8_t * restrict q5 = x[i].qs;
|
const uint8_t * restrict q5 = x[i].qs;
|
||||||
const uint8_t * restrict qh = x[i].qh;
|
const uint8_t * restrict qh = x[i].qh;
|
||||||
|
@ -2719,27 +2729,32 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_DOTPROD)
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
|
|
||||||
sumb += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * (float)x[i].d[0];
|
acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])),
|
||||||
sumb += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * (float)x[i].d[2];
|
vgetq_lane_f32(s32, 0));
|
||||||
|
acc2 = vmlaq_n_f32(acc2, vcvtq_f32_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])),
|
||||||
|
vgetq_lane_f32(s32, 2));
|
||||||
#else
|
#else
|
||||||
|
|
||||||
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
sumb += vaddvq_s16(vaddq_s16(p0, p1)) * (float)x[i].d[0];
|
const int16x8_t p01_16 = vaddq_s16(p0, p1);
|
||||||
|
const int32x4_t p01_32 = vaddq_s32(vmovl_s16(vget_low_s16(p01_16)), vmovl_s16(vget_high_s16(p01_16)));
|
||||||
|
acc1 = vmlaq_n_f32(acc1, vcvtq_f32_s32(p01_32), vgetq_lane_f32(s32, 0));
|
||||||
|
|
||||||
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
sumb += vaddvq_s16(vaddq_s16(p2, p3)) * (float)x[i].d[2];
|
const int16x8_t p02_16 = vaddq_s16(p2, p3);
|
||||||
|
const int32x4_t p02_32 = vaddq_s32(vmovl_s16(vget_low_s16(p02_16)), vmovl_s16(vget_high_s16(p02_16)));
|
||||||
|
acc2 = vmlaq_n_f32(acc2, vcvtq_f32_s32(p02_32), vgetq_lane_f32(s32, 2));
|
||||||
#endif
|
#endif
|
||||||
sumf += y[i].d * sumb;
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*s = sumf;
|
*s = vaddvq_f32(vaddq_f32(acc1, acc2)) + sumf;
|
||||||
|
|
||||||
#elif defined __AVX2__
|
#elif defined __AVX2__
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue