k_quants: WIP super-blocks with 64 weights
Q3_K working on ARM_NEON, but quite a bit slower than 256 weights.
This commit is contained in:
parent
80c75fe821
commit
2b2a13c4f9
1 changed files with 34 additions and 88 deletions
122
k_quants.c
122
k_quants.c
|
@ -1877,21 +1877,14 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
const int nb = n / QK_K;
|
const int nb = n / QK_K;
|
||||||
|
|
||||||
#ifdef z__ARM_NEON
|
#ifdef __ARM_NEON
|
||||||
|
|
||||||
uint32_t aux[3];
|
|
||||||
uint32_t utmp[4];
|
|
||||||
|
|
||||||
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
|
||||||
#ifdef __ARM_FEATURE_DOTPROD
|
#ifdef __ARM_FEATURE_DOTPROD
|
||||||
const int32x4_t vzero = vdupq_n_s32(0);
|
const int32x4_t vzero = vdupq_n_s32(0);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
const uint8x16_t m0 = vdupq_n_u8(1);
|
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
||||||
const uint8x16_t m1 = vshlq_n_u8(m0, 1);
|
const uint8x16_t m0 = vdupq_n_u8(1);
|
||||||
const uint8x16_t m2 = vshlq_n_u8(m0, 2);
|
|
||||||
const uint8x16_t m3 = vshlq_n_u8(m0, 3);
|
|
||||||
const int8_t m32 = 32;
|
|
||||||
|
|
||||||
int8x16x4_t q3bytes;
|
int8x16x4_t q3bytes;
|
||||||
|
|
||||||
|
@ -1899,96 +1892,49 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
||||||
|
|
||||||
for (int i = 0; i < nb; ++i) {
|
for (int i = 0; i < nb; ++i) {
|
||||||
|
|
||||||
const float d = y[i].d * ggml_fp16_to_fp32(x[i].d);
|
|
||||||
|
|
||||||
const uint8_t * restrict q3 = x[i].qs;
|
|
||||||
const uint8_t * restrict qh = x[i].hmask;
|
|
||||||
const int8_t * restrict q8 = y[i].qs;
|
|
||||||
|
|
||||||
uint8x16x2_t qhbits = vld1q_u8_x2(qh);
|
|
||||||
|
|
||||||
uint8x16x4_t q3h;
|
uint8x16x4_t q3h;
|
||||||
|
|
||||||
int32_t isum = 0;
|
const uint8x8_t hbits = vld1_u8(x[i].hmask);
|
||||||
|
const uint8x16_t q3bits = vld1q_u8(x[i].qs);
|
||||||
|
const int8x16x4_t q8bytes = vld1q_s8_x4(y[i].qs);
|
||||||
|
|
||||||
// Set up scales
|
const int8_t * restrict scales = x[i].scales;
|
||||||
memcpy(aux, x[i].scales, 12);
|
int32_t isum = -4*(scales[0] * y[i].bsums[0] + scales[1] * y[i].bsums[1] + scales[2] * y[i].bsums[2] + scales[3] * y[i].bsums[3]);
|
||||||
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
|
|
||||||
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
|
|
||||||
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
|
|
||||||
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
|
|
||||||
|
|
||||||
int8_t * scale = (int8_t *)utmp;
|
const float d = y[i].d * (float)x[i].d;
|
||||||
for (int j = 0; j < 16; ++j) scale[j] -= m32;
|
|
||||||
|
|
||||||
for (int j = 0; j < QK_K/128; ++j) {
|
q3h.val[0] = vandq_u8(m0, vcombine_u8(hbits, vshr_n_u8(hbits, 1)));
|
||||||
|
q3h.val[1] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 2), vshr_n_u8(hbits, 3)));
|
||||||
|
q3h.val[2] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 4), vshr_n_u8(hbits, 5)));
|
||||||
|
q3h.val[3] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 6), vshr_n_u8(hbits, 7)));
|
||||||
|
|
||||||
const uint8x16x2_t q3bits = vld1q_u8_x2(q3); q3 += 32;
|
q3h.val[0] = vshlq_n_u8(q3h.val[0], 2);
|
||||||
const int8x16x4_t q8bytes_1 = vld1q_s8_x4(q8); q8 += 64;
|
q3h.val[1] = vshlq_n_u8(q3h.val[1], 2);
|
||||||
const int8x16x4_t q8bytes_2 = vld1q_s8_x4(q8); q8 += 64;
|
q3h.val[2] = vshlq_n_u8(q3h.val[2], 2);
|
||||||
|
q3h.val[3] = vshlq_n_u8(q3h.val[3], 2);
|
||||||
|
|
||||||
q3h.val[0] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[0]), 2);
|
q3bytes.val[0] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
|
||||||
q3h.val[1] = vshlq_n_u8(vbicq_u8(m0, qhbits.val[1]), 2);
|
q3bytes.val[1] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
|
||||||
q3h.val[2] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[0]), 1);
|
q3bytes.val[2] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
|
||||||
q3h.val[3] = vshlq_n_u8(vbicq_u8(m1, qhbits.val[1]), 1);
|
q3bytes.val[3] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 6), m3b), q3h.val[3]));
|
||||||
|
|
||||||
q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[0], m3b)), vreinterpretq_s8_u8(q3h.val[0]));
|
|
||||||
q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(q3bits.val[1], m3b)), vreinterpretq_s8_u8(q3h.val[1]));
|
|
||||||
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 2), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
|
|
||||||
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 2), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
|
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_DOTPROD)
|
#if defined(__ARM_FEATURE_DOTPROD)
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_1.val[0])) * scale[0];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_1.val[1])) * scale[1];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes.val[1])) * scales[1];
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_1.val[2])) * scale[2];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes.val[2])) * scales[2];
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_1.val[3])) * scale[3];
|
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes.val[3])) * scales[3];
|
||||||
#else
|
#else
|
||||||
int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_1.val[0])),
|
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||||
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_1.val[0])));
|
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||||
int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_1.val[1])),
|
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||||
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_1.val[1])));
|
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||||
int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_1.val[2])),
|
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||||
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_1.val[2])));
|
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||||
int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_1.val[3])),
|
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||||
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_1.val[3])));
|
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||||
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
|
isum += vaddvq_s16(p0) * scales[0] + vaddvq_s16(p1) * scales[1] + vaddvq_s16(p2) * scales[2] + vaddvq_s16(p3) * scales[3];
|
||||||
#endif
|
#endif
|
||||||
scale += 4;
|
|
||||||
|
|
||||||
q3h.val[0] = vbicq_u8(m2, qhbits.val[0]);
|
|
||||||
q3h.val[1] = vbicq_u8(m2, qhbits.val[1]);
|
|
||||||
q3h.val[2] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[0]), 1);
|
|
||||||
q3h.val[3] = vshrq_n_u8(vbicq_u8(m3, qhbits.val[1]), 1);
|
|
||||||
|
|
||||||
q3bytes.val[0] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 4), m3b)), vreinterpretq_s8_u8(q3h.val[0]));
|
|
||||||
q3bytes.val[1] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 4), m3b)), vreinterpretq_s8_u8(q3h.val[1]));
|
|
||||||
q3bytes.val[2] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[0], 6), m3b)), vreinterpretq_s8_u8(q3h.val[2]));
|
|
||||||
q3bytes.val[3] = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(vshrq_n_u8(q3bits.val[1], 6), m3b)), vreinterpretq_s8_u8(q3h.val[3]));
|
|
||||||
|
|
||||||
#if defined(__ARM_FEATURE_DOTPROD)
|
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes_2.val[0])) * scale[0];
|
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[1], q8bytes_2.val[1])) * scale[1];
|
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[2], q8bytes_2.val[2])) * scale[2];
|
|
||||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[3], q8bytes_2.val[3])) * scale[3];
|
|
||||||
#else
|
|
||||||
p0 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[0]), vget_low_s8 (q8bytes_2.val[0])),
|
|
||||||
vmull_s8(vget_high_s8(q3bytes.val[0]), vget_high_s8(q8bytes_2.val[0])));
|
|
||||||
p1 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[1]), vget_low_s8 (q8bytes_2.val[1])),
|
|
||||||
vmull_s8(vget_high_s8(q3bytes.val[1]), vget_high_s8(q8bytes_2.val[1])));
|
|
||||||
p2 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[2]), vget_low_s8 (q8bytes_2.val[2])),
|
|
||||||
vmull_s8(vget_high_s8(q3bytes.val[2]), vget_high_s8(q8bytes_2.val[2])));
|
|
||||||
p3 = vaddq_s16(vmull_s8(vget_low_s8 (q3bytes.val[3]), vget_low_s8 (q8bytes_2.val[3])),
|
|
||||||
vmull_s8(vget_high_s8(q3bytes.val[3]), vget_high_s8(q8bytes_2.val[3])));
|
|
||||||
isum += vaddvq_s16(p0) * scale[0] + vaddvq_s16(p1) * scale[1] + vaddvq_s16(p2) * scale[2] + vaddvq_s16(p3) * scale[3];
|
|
||||||
#endif
|
|
||||||
scale += 4;
|
|
||||||
|
|
||||||
if (j == 0) {
|
|
||||||
qhbits.val[0] = vshrq_n_u8(qhbits.val[0], 4);
|
|
||||||
qhbits.val[1] = vshrq_n_u8(qhbits.val[1], 4);
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
sum += d * isum;
|
sum += d * isum;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue