iq1_m: slightly faster ARM_NEON dot product

10.5 t/s -> 11.65 t/s

Author: Iwan Kawrakow
Date:   2024-03-25 12:34:25 +01:00
parent  dff85a804b
commit  f664692fa8

@@ -9765,8 +9765,8 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * restrict s, size_t bs, const void
     const int8x8_t minus1 = vdup_n_s8(-1);
     const int8x8_t plus1 = vdup_n_s8(+1);
     const int32x4_t mask = vdupq_n_s32(0x7);
     const int32x4_t mone = vdupq_n_s32(1);
     const int32x4_t mzero = vdupq_n_s32(0);
     ggml_int8x16x4_t q1b;
@@ -9803,6 +9803,7 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * restrict s, size_t bs, const void
         const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
         const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
+        const int32x4_t p12 = vpaddq_s32(p1, p2);

         delta.val[0] = vcombine_s8(qh[0] & 0x08 ? minus1 : plus1, qh[0] & 0x80 ? minus1 : plus1);
         delta.val[1] = vcombine_s8(qh[1] & 0x08 ? minus1 : plus1, qh[1] & 0x80 ? minus1 : plus1);
@@ -9811,16 +9812,13 @@ void ggml_vec_dot_iq1_m_q8_K(int n, float * restrict s, size_t bs, const void
         const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, delta.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, delta.val[1], q8b.val[1]));
         const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, delta.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, delta.val[3], q8b.val[3]));
+        const int32x4_t p34 = vpaddq_s32(p3, p4);

-        int32x4_t scale1 = vcombine_s32(vdup_n_s32(sc[ib/2] >> 0), vdup_n_s32(sc[ib/2] >> 3));
-        int32x4_t scale2 = vcombine_s32(vdup_n_s32(sc[ib/2] >> 6), vdup_n_s32(sc[ib/2] >> 9));
-        scale1 = vaddq_s32(vshlq_n_s32(vandq_s32(scale1, mask), 1), mone);
-        scale2 = vaddq_s32(vshlq_n_s32(vandq_s32(scale2, mask), 1), mone);
+        int32x4_t scales_4 = {sc[ib/2] >> 0, sc[ib/2] >> 3, sc[ib/2] >> 6, sc[ib/2] >> 9};
+        scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);

-        sumi1 = vmlaq_s32(sumi1, scale1, p1);
-        sumi1 = vmlaq_s32(sumi1, scale2, p2);
-        sumi2 = vmlaq_s32(sumi2, scale1, p3);
-        sumi2 = vmlaq_s32(sumi2, scale2, p4);
+        sumi1 = vmlaq_s32(sumi1, scales_4, p12);
+        sumi2 = vmlaq_s32(sumi2, scales_4, p34);

         qs += 8; qh += 4;
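
For context, a minimal standalone sketch of the trick, not code from ggml itself: the two partial dot products are pairwise-added into one vector with vpaddq_s32, and all four 3-bit block scales are decoded into a single vector, so one vmlaq_s32 replaces two per accumulator (four across sumi1/sumi2). The toy inputs, the packed scale word, and the main() scaffolding are illustrative assumptions; only the intrinsic sequence mirrors the diff. It assumes an AArch64 compiler with <arm_neon.h>.

// Standalone sketch of the optimization above; toy values, not ggml code.
// Assumed build: cc -O2 demo.c -o demo on an AArch64 machine.
#include <arm_neon.h>
#include <stdio.h>

int main(void) {
    // Two partial dot products, standing in for the sdot results p1/p2.
    const int32x4_t p1 = {1, 2, 3, 4};
    const int32x4_t p2 = {5, 6, 7, 8};
    // Four 3-bit scales packed into one word, as in sc[ib/2].
    const int sc = (3 << 0) | (5 << 3) | (2 << 6) | (7 << 9);
    const int32x4_t mask = vdupq_n_s32(0x7);
    const int32x4_t mone = vdupq_n_s32(1);

    // Old path: two scale vectors, two multiply-accumulates.
    int32x4_t scale1 = vcombine_s32(vdup_n_s32(sc >> 0), vdup_n_s32(sc >> 3));
    int32x4_t scale2 = vcombine_s32(vdup_n_s32(sc >> 6), vdup_n_s32(sc >> 9));
    scale1 = vaddq_s32(vshlq_n_s32(vandq_s32(scale1, mask), 1), mone);  // 2*s + 1
    scale2 = vaddq_s32(vshlq_n_s32(vandq_s32(scale2, mask), 1), mone);
    const int32x4_t sum_old = vmlaq_s32(vmlaq_s32(vdupq_n_s32(0), scale1, p1), scale2, p2);

    // New path: fold p1/p2 into one vector, then a single multiply-accumulate
    // against all four decoded scales.
    const int32x4_t p12 = vpaddq_s32(p1, p2);  // {p1[0]+p1[1], p1[2]+p1[3], p2[0]+p2[1], p2[2]+p2[3]}
    int32x4_t scales_4 = {sc >> 0, sc >> 3, sc >> 6, sc >> 9};
    scales_4 = vaddq_s32(vshlq_n_s32(vandq_s32(scales_4, mask), 1), mone);
    const int32x4_t sum_new = vmlaq_s32(vdupq_n_s32(0), scales_4, p12);

    // Both reduce to the same scalar (378 with these toy values).
    printf("old = %d, new = %d\n", vaddvq_s32(sum_old), vaddvq_s32(sum_new));
    return 0;
}

Per pair of partial products this trades four vector MLAs plus the vdup/vcombine scale shuffles for two MLAs and two pairwise adds; ADDP is cheap on typical AArch64 cores, which is consistent with the reported 10.5 t/s -> 11.65 t/s.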