iq1_m: ARM_NEON dot product
Works, but very slow (10.5 t/s)
This commit is contained in:
parent
379fdb671b
commit
8009b6d63b
1 changed files with 46 additions and 22 deletions
|
@ -9761,48 +9761,72 @@ void ggml_vec_dot_iq1_m_q8_K (int n, float * restrict s, size_t bs, const void
|
|||
|
||||
const int nb = n / QK_K;
|
||||
|
||||
#if defined z__ARM_NEON
|
||||
// TODO
|
||||
#if defined __ARM_NEON
|
||||
|
||||
const int8x8_t minus1 = vdup_n_s8(-1);
|
||||
const int8x8_t plus1 = vdup_n_s8(+1);
|
||||
const int32x4_t mask = vdupq_n_s32(0x7);
|
||||
const int32x4_t mone = vdupq_n_s32(1);
|
||||
const int32x4_t mzero = vdupq_n_s32(0);
|
||||
|
||||
ggml_int8x16x4_t q1b;
|
||||
ggml_int8x16x4_t q8b;
|
||||
ggml_int8x16x4_t delta;
|
||||
|
||||
iq1m_scale_t scale;
|
||||
|
||||
float sumf = 0;
|
||||
for (int i = 0; i < nb; ++i) {
|
||||
|
||||
const int8_t * q8 = y[i].qs;
|
||||
const uint8_t * qs = x[i].qs;
|
||||
const uint16_t * qh = x[i].qh;
|
||||
const uint8_t * qh = x[i].qh;
|
||||
const uint16_t * sc = (const uint16_t *)x[i].scales;
|
||||
|
||||
int sumi1 = 0, sumi2 = 0, sumi3 = 0;
|
||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||
|
||||
int32x4_t sumi1 = mzero;
|
||||
int32x4_t sumi2 = mzero;
|
||||
|
||||
for (int ib = 0; ib < QK_K/32; ib += 2) {
|
||||
|
||||
q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[ib+0] << 8) & 0x700)))),
|
||||
vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[ib+0] << 5) & 0x700)))));
|
||||
q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[ib+0] << 2) & 0x700)))),
|
||||
vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[ib+0] >> 1) & 0x700)))));
|
||||
q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[ib+1] << 8) & 0x700)))),
|
||||
vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[ib+1] << 5) & 0x700)))));
|
||||
q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[ib+1] << 2) & 0x700)))),
|
||||
vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[ib+1] >> 1) & 0x700)))));
|
||||
qs += 8;
|
||||
q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700)))),
|
||||
vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 4) & 0x700)))));
|
||||
q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | ((qh[1] << 8) & 0x700)))),
|
||||
vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | ((qh[1] << 4) & 0x700)))));
|
||||
q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | ((qh[2] << 8) & 0x700)))),
|
||||
vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | ((qh[2] << 4) & 0x700)))));
|
||||
q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | ((qh[3] << 8) & 0x700)))),
|
||||
vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | ((qh[3] << 4) & 0x700)))));
|
||||
|
||||
q8b = ggml_vld1q_s8_x4(q8); q8 += 64;
|
||||
|
||||
const int32x4_t p1 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[0], q8b.val[0]), q1b.val[1], q8b.val[1]);
|
||||
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q1b.val[2], q8b.val[2]), q1b.val[3], q8b.val[3]);
|
||||
const int32x4_t p1 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, q1b.val[1], q8b.val[1]));
|
||||
const int32x4_t p2 = vpaddq_s32(ggml_vdotq_s32(mzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, q1b.val[3], q8b.val[3]));
|
||||
|
||||
const int ls1 = 2*((qh[ib+0] >> 12) & 7) + 1;
|
||||
const int ls2 = 2*((qh[ib+1] >> 12) & 7) + 1;
|
||||
sumi1 += vaddvq_s32(p1) * ls1;
|
||||
sumi2 += vaddvq_s32(p2) * ls2;
|
||||
sumi3 += (y[i].bsums[2*ib+0] + y[i].bsums[2*ib+1]) * ls1 * (qh[ib+0] & 0x8000 ? -1 : 1)
|
||||
+ (y[i].bsums[2*ib+2] + y[i].bsums[2*ib+3]) * ls2 * (qh[ib+1] & 0x8000 ? -1 : 1);
|
||||
delta.val[0] = vcombine_s8(qh[0] & 0x08 ? minus1 : plus1, qh[0] & 0x80 ? minus1 : plus1);
|
||||
delta.val[1] = vcombine_s8(qh[1] & 0x08 ? minus1 : plus1, qh[1] & 0x80 ? minus1 : plus1);
|
||||
delta.val[2] = vcombine_s8(qh[2] & 0x08 ? minus1 : plus1, qh[2] & 0x80 ? minus1 : plus1);
|
||||
delta.val[3] = vcombine_s8(qh[3] & 0x08 ? minus1 : plus1, qh[3] & 0x80 ? minus1 : plus1);
|
||||
|
||||
const int32x4_t p3 = vpaddq_s32(ggml_vdotq_s32(mzero, delta.val[0], q8b.val[0]), ggml_vdotq_s32(mzero, delta.val[1], q8b.val[1]));
|
||||
const int32x4_t p4 = vpaddq_s32(ggml_vdotq_s32(mzero, delta.val[2], q8b.val[2]), ggml_vdotq_s32(mzero, delta.val[3], q8b.val[3]));
|
||||
|
||||
int32x4_t scale1 = vcombine_s32(vdup_n_s32(sc[ib/2] >> 0), vdup_n_s32(sc[ib/2] >> 3));
|
||||
int32x4_t scale2 = vcombine_s32(vdup_n_s32(sc[ib/2] >> 6), vdup_n_s32(sc[ib/2] >> 9));
|
||||
scale1 = vaddq_s32(vshlq_n_s32(vandq_s32(scale1, mask), 1), mone);
|
||||
scale2 = vaddq_s32(vshlq_n_s32(vandq_s32(scale2, mask), 1), mone);
|
||||
|
||||
sumi1 = vmlaq_s32(sumi1, scale1, p1);
|
||||
sumi1 = vmlaq_s32(sumi1, scale2, p2);
|
||||
sumi2 = vmlaq_s32(sumi2, scale1, p3);
|
||||
sumi2 = vmlaq_s32(sumi2, scale2, p4);
|
||||
|
||||
qs += 8; qh += 4;
|
||||
|
||||
}
|
||||
|
||||
sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2 + IQ1S_DELTA * sumi3);
|
||||
sumf += y[i].d * GGML_FP16_TO_FP32(scale.fp16) * (vaddvq_s32(sumi1) + IQ1M_DELTA * vaddvq_s32(sumi2));
|
||||
}
|
||||
|
||||
*s = sumf;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue