k_quants: WIP super-blocks with 64 weights
Slightly more efficient Q3_K and Q5_K
This commit is contained in:
parent
9d27d8d0ea
commit
2ff543c147
1 changed files with 24 additions and 28 deletions
52
k_quants.c
52
k_quants.c
|
@ -1884,7 +1884,7 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
#endif
|
||||
|
||||
const uint8x16_t m3b = vdupq_n_u8(0x3);
|
||||
const uint8x16_t m0 = vdupq_n_u8(1);
|
||||
const uint8x16_t mh = vdupq_n_u8(4);
|
||||
|
||||
int8x16x4_t q3bytes;
|
||||
|
||||
|
@ -1903,20 +1903,15 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
const float d = y[i].d * (float)x[i].d;
|
||||
|
||||
q3h.val[0] = vandq_u8(m0, vcombine_u8(hbits, vshr_n_u8(hbits, 1)));
|
||||
q3h.val[1] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 2), vshr_n_u8(hbits, 3)));
|
||||
q3h.val[2] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 4), vshr_n_u8(hbits, 5)));
|
||||
q3h.val[3] = vandq_u8(m0, vcombine_u8(vshr_n_u8(hbits, 6), vshr_n_u8(hbits, 7)));
|
||||
q3h.val[0] = vandq_u8(mh, vcombine_u8(vshl_n_u8(hbits, 2), vshl_n_u8(hbits, 1)));
|
||||
q3h.val[1] = vandq_u8(mh, vcombine_u8(hbits, vshr_n_u8(hbits, 1)));
|
||||
q3h.val[2] = vandq_u8(mh, vcombine_u8(vshr_n_u8(hbits, 2), vshr_n_u8(hbits, 3)));
|
||||
q3h.val[3] = vandq_u8(mh, vcombine_u8(vshr_n_u8(hbits, 4), vshr_n_u8(hbits, 5)));
|
||||
|
||||
q3h.val[0] = vshlq_n_u8(q3h.val[0], 2);
|
||||
q3h.val[1] = vshlq_n_u8(q3h.val[1], 2);
|
||||
q3h.val[2] = vshlq_n_u8(q3h.val[2], 2);
|
||||
q3h.val[3] = vshlq_n_u8(q3h.val[3], 2);
|
||||
|
||||
q3bytes.val[0] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
|
||||
q3bytes.val[1] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
|
||||
q3bytes.val[2] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
|
||||
q3bytes.val[3] = vreinterpretq_s8_u8(vaddq_u8(vandq_u8(vshrq_n_u8(q3bits, 6), m3b), q3h.val[3]));
|
||||
q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
|
||||
q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
|
||||
q3bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 4), m3b), q3h.val[2]));
|
||||
q3bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q3bits, 6), q3h.val[3]));
|
||||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
isum += vaddvq_s32(vdotq_s32(vzero, q3bytes.val[0], q8bytes.val[0])) * scales[0];
|
||||
|
@ -2690,7 +2685,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
const uint8x16_t m4b = vdupq_n_u8(0xf);
|
||||
const int32x4_t mzero = vdupq_n_s32(0);
|
||||
const uint8x16_t mone = vdupq_n_u8(1);
|
||||
const uint8x16_t mh = vdupq_n_u8(16);
|
||||
|
||||
int8x16x4_t q5bytes;
|
||||
uint8x16x4_t q5h;
|
||||
|
@ -2699,7 +2694,7 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
sumf -= y[i].d * ((float)x[i].d[1] * (y[i].bsums[0] + y[i].bsums[1]) + (float)x[i].d[3] * (y[i].bsums[2] + y[i].bsums[3]));
|
||||
float sumb = -((float)x[i].d[1] * (y[i].bsums[0] + y[i].bsums[1]) + (float)x[i].d[3] * (y[i].bsums[2] + y[i].bsums[3]));
|
||||
|
||||
const uint8_t * restrict q5 = x[i].qs;
|
||||
const uint8_t * restrict qh = x[i].qh;
|
||||
|
@ -2710,34 +2705,35 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
|
|||
const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
|
||||
const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
|
||||
|
||||
q5h.val[0] = vandq_u8(mone, vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)));
|
||||
q5h.val[1] = vandq_u8(mone, vcombine_u8(vshr_n_u8(qhbits, 2), vshr_n_u8(qhbits, 3)));
|
||||
q5h.val[2] = vandq_u8(mone, vcombine_u8(vshr_n_u8(qhbits, 4), vshr_n_u8(qhbits, 5)));
|
||||
q5h.val[3] = vandq_u8(mone, vcombine_u8(vshr_n_u8(qhbits, 6), vshr_n_u8(qhbits, 7)));
|
||||
q5h.val[0] = vandq_u8(mh, vcombine_u8(vshl_n_u8(qhbits, 4), vshl_n_u8(qhbits, 3)));
|
||||
q5h.val[1] = vandq_u8(mh, vcombine_u8(vshl_n_u8(qhbits, 2), vshl_n_u8(qhbits, 1)));
|
||||
q5h.val[2] = vandq_u8(mh, vcombine_u8(qhbits, vshr_n_u8(qhbits, 1)));
|
||||
q5h.val[3] = vandq_u8(mh, vcombine_u8(vshr_n_u8(qhbits, 2), vshr_n_u8(qhbits, 3)));
|
||||
|
||||
q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), vshlq_n_u8(q5h.val[0], 4)));
|
||||
q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), vshlq_n_u8(q5h.val[1], 4)));
|
||||
q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), vshlq_n_u8(q5h.val[2], 4)));
|
||||
q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), vshlq_n_u8(q5h.val[3], 4)));
|
||||
q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
|
||||
q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));
|
||||
q5bytes.val[2] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[0], 4), q5h.val[2]));
|
||||
q5bytes.val[3] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5bits.val[1], 4), q5h.val[3]));
|
||||
|
||||
#if defined(__ARM_FEATURE_DOTPROD)
|
||||
|
||||
sumf += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * y[i].d * (float)x[i].d[0];
|
||||
sumf += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * y[i].d * (float)x[i].d[2];
|
||||
sumb += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[0], q8bytes.val[0]), q5bytes.val[1], q8bytes.val[1])) * (float)x[i].d[0];
|
||||
sumb += vaddvq_s32(vdotq_s32(vdotq_s32(mzero, q5bytes.val[2], q8bytes.val[2]), q5bytes.val[3], q8bytes.val[3])) * (float)x[i].d[2];
|
||||
#else
|
||||
|
||||
const int16x8_t p0 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[0]), vget_low_s8 (q8bytes.val[0])),
|
||||
vmull_s8(vget_high_s8(q5bytes.val[0]), vget_high_s8(q8bytes.val[0])));
|
||||
const int16x8_t p1 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[1]), vget_low_s8 (q8bytes.val[1])),
|
||||
vmull_s8(vget_high_s8(q5bytes.val[1]), vget_high_s8(q8bytes.val[1])));
|
||||
sumf += vaddvq_s16(vaddq_s16(p0, p1)) * y[i].d * (float)x[i].d[0];
|
||||
sumb += vaddvq_s16(vaddq_s16(p0, p1)) * (float)x[i].d[0];
|
||||
|
||||
const int16x8_t p2 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[2]), vget_low_s8 (q8bytes.val[2])),
|
||||
vmull_s8(vget_high_s8(q5bytes.val[2]), vget_high_s8(q8bytes.val[2])));
|
||||
const int16x8_t p3 = vaddq_s16(vmull_s8(vget_low_s8 (q5bytes.val[3]), vget_low_s8 (q8bytes.val[3])),
|
||||
vmull_s8(vget_high_s8(q5bytes.val[3]), vget_high_s8(q8bytes.val[3])));
|
||||
sumf += vaddvq_s16(vaddq_s16(p2, p3)) * y[i].d * (float)x[i].d[2];
|
||||
sumb += vaddvq_s16(vaddq_s16(p2, p3)) * (float)x[i].d[2];
|
||||
#endif
|
||||
sumf += y[i].d * sumb;
|
||||
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue