iq3s_mult: ARM and Metal
This commit is contained in:
parent
b6402fa757
commit
5b9c8785fa
2 changed files with 46 additions and 14 deletions
|
@ -10070,6 +10070,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
vmovl_u8(vget_low_u8(idx_l)));
|
||||
const uint16x8_t idx_2 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), idx_shift), idx_mask1),
|
||||
vmovl_u8(vget_high_u8(idx_l)));
|
||||
#ifdef IQ3S_SLOW_MULT
|
||||
q3s.val[0] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2));
|
||||
q3s.val[1] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2));
|
||||
q3s.val[2] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2));
|
||||
|
@ -10078,6 +10079,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
q3s.val[1] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[1], m1), m0), 1), 1), m1);
|
||||
q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1);
|
||||
q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1);
|
||||
#else
|
||||
q3s.val[0] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)), m1);
|
||||
q3s.val[1] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)), m1);
|
||||
q3s.val[2] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)), m1);
|
||||
q3s.val[3] = vorrq_s8(vreinterpretq_u8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)), m1);
|
||||
#endif
|
||||
|
||||
vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
|
||||
vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
|
||||
|
@ -10094,8 +10101,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1));
|
||||
vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1));
|
||||
|
||||
signs += 4;
|
||||
|
||||
q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[2]);
|
||||
q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[3]);
|
||||
|
||||
|
@ -10103,6 +10108,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
|
|||
const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
|
||||
sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
|
||||
sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
|
||||
|
||||
signs += 4;
|
||||
}
|
||||
sumf += d*(sumi1 + sumi2);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue