k_quants: WIP super-blocks with 64 weights

Another small improvement for Q3_K and Q5_K on ARM_NEON
This commit is contained in:
Iwan Kawrakow 2023-06-23 08:43:04 +03:00
parent 2ff543c147
commit d92c5a9e29

View file

@@ -1903,10 +1903,11 @@ void ggml_vec_dot_q3_K_q8_K(const int n, float * restrict s, const void * restri
const float d = y[i].d * (float)x[i].d; const float d = y[i].d * (float)x[i].d;
q3h.val[0] = vandq_u8(mh, vcombine_u8(vshl_n_u8(hbits, 2), vshl_n_u8(hbits, 1))); const uint8x16_t htmp = vcombine_u8(hbits, vshr_n_u8(hbits, 1));
q3h.val[1] = vandq_u8(mh, vcombine_u8(hbits, vshr_n_u8(hbits, 1))); q3h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
q3h.val[2] = vandq_u8(mh, vcombine_u8(vshr_n_u8(hbits, 2), vshr_n_u8(hbits, 3))); q3h.val[1] = vandq_u8(mh, htmp);
q3h.val[3] = vandq_u8(mh, vcombine_u8(vshr_n_u8(hbits, 4), vshr_n_u8(hbits, 5))); q3h.val[2] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
q3h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 4));
q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0])); q3bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q3bits, m3b), q3h.val[0]));
q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1])); q3bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(vshrq_n_u8(q3bits, 2), m3b), q3h.val[1]));
@@ -2705,10 +2706,11 @@ void ggml_vec_dot_q5_K_q8_K(const int n, float * restrict s, const void * restri
const uint8x16x2_t q5bits = vld1q_u8_x2(q5); const uint8x16x2_t q5bits = vld1q_u8_x2(q5);
const int8x16x4_t q8bytes = vld1q_s8_x4(q8); const int8x16x4_t q8bytes = vld1q_s8_x4(q8);
q5h.val[0] = vandq_u8(mh, vcombine_u8(vshl_n_u8(qhbits, 4), vshl_n_u8(qhbits, 3))); const uint8x16_t htmp = vcombine_u8(qhbits, vshr_n_u8(qhbits, 1));
q5h.val[1] = vandq_u8(mh, vcombine_u8(vshl_n_u8(qhbits, 2), vshl_n_u8(qhbits, 1))); q5h.val[0] = vandq_u8(mh, vshlq_n_u8(htmp, 4));
q5h.val[2] = vandq_u8(mh, vcombine_u8(qhbits, vshr_n_u8(qhbits, 1))); q5h.val[1] = vandq_u8(mh, vshlq_n_u8(htmp, 2));
q5h.val[3] = vandq_u8(mh, vcombine_u8(vshr_n_u8(qhbits, 2), vshr_n_u8(qhbits, 3))); q5h.val[2] = vandq_u8(mh, htmp);
q5h.val[3] = vandq_u8(mh, vshrq_n_u8(htmp, 2));
q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0])); q5bytes.val[0] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[0], m4b), q5h.val[0]));
q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1])); q5bytes.val[1] = vreinterpretq_s8_u8(vorrq_u8(vandq_u8(q5bits.val[1], m4b), q5h.val[1]));