iq3s_mult: ARM and Metal

This commit is contained in:
Iwan Kawrakow 2024-03-03 11:30:01 +02:00
parent b6402fa757
commit 5b9c8785fa
2 changed files with 46 additions and 14 deletions

View file

@ -10070,6 +10070,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
vmovl_u8(vget_low_u8(idx_l)));
const uint16x8_t idx_2 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), idx_shift), idx_mask1),
vmovl_u8(vget_high_u8(idx_l)));
#ifdef IQ3S_SLOW_MULT
q3s.val[0] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2));
q3s.val[1] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2));
q3s.val[2] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2));
@ -10078,6 +10079,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
q3s.val[1] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[1], m1), m0), 1), 1), m1);
q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1);
q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1);
#else
q3s.val[0] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)), m1);
q3s.val[1] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)), m1);
q3s.val[2] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)), m1);
q3s.val[3] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)), m1);
#endif
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
@ -10094,8 +10101,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1));
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1));
signs += 4;
q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[2]);
q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[3]);
@ -10103,6 +10108,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
signs += 4;
}
sumf += d*(sumi1 + sumi2);
}