iq3_s_mult_shuffle: works on ARM_NEON and Metal

This commit is contained in:
Iwan Kawrakow 2024-03-04 20:10:36 +02:00
parent b587482287
commit a6a263b919
2 changed files with 21 additions and 16 deletions

View file

@ -10005,6 +10005,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
const uint8x16x2_t mask1 = vld1q_u8_x2(k_mask1);
const uint8x16_t mask2 = vld1q_u8(k_mask2);
const uint8x16_t shuff = vld1q_u8(iq3s_values);
const uint32x4_t idx_mult = vdupq_n_u32(IQ3S_MULTIPLIER);
const int16x8_t idx_shift = vld1q_s16(k_shift);
@ -10042,10 +10043,10 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1);
q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1);
#else
q3s.val[0] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)), m1);
q3s.val[1] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)), m1);
q3s.val[2] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)), m1);
q3s.val[3] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)), m1);
q3s.val[0] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)));
q3s.val[1] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)));
q3s.val[2] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)));
q3s.val[3] = vqtbl1q_s8(shuff, vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)));
#endif
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));