diff --git a/ggml-metal.metal b/ggml-metal.metal
index 1615f8cea..69a928c24 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -2546,7 +2546,12 @@ typedef struct {
     uint8_t signs[QK_K/8];
     uint8_t scales[IQ3S_N_SCALE];
 } block_iq3_s;
+#ifdef IQ3S_SLOW_MULT
 #define IQ3S_MULTIPLIER 190842953
+#else
+//#define IQ3S_MULTIPLIER 898886
+#define IQ3S_MULTIPLIER 842866
+#endif
 
 typedef struct {
     half d;
@@ -4691,15 +4696,21 @@ void kernel_mul_mv_iq3_s_f32_impl(
     threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
     {
-        uint32_t aux32;
-        thread int8_t * q = (thread int8_t *)&aux32;
         int nval = 8;
         int pos = (32*sgitg + tiisg)*nval;
+#ifdef IQ3S_SLOW_MULT
+        uint32_t aux32;
+        thread int8_t * q = (thread int8_t *)&aux32;
         for (int i = 0; i < nval; ++i) {
             aux32 = (IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f;
             for (int k = 0; k < 4; ++k) q[k] = 2*((q[k]-1)/2) + 1;
             values[pos + i] = aux32;
         }
+#else
+        for (int i = 0; i < nval; ++i) {
+            values[pos + i] = ((IQ3S_MULTIPLIER * (pos + i)) & 0x0f0f0f0f) | 0x01010101;
+        }
+#endif
         threadgroup_barrier(mem_flags::mem_threadgroup);
     }
@@ -4733,17 +4744,16 @@ void kernel_mul_mv_iq3_s_f32_impl(
         float2 sum = {0};
         for (int l = 0; l < 4; ++l) {
-            //aux32[0] = (IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f;
-            //aux32[1] = (IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f;
-            //threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
-            //threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
-            threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + qs[2*l+0] +
-                    select(0, 256, qh[0] & kmask_iq2xs[2*l+0]));
-            threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + qs[2*l+1] +
-                    select(0, 256, qh[0] & kmask_iq2xs[2*l+1]));
+            // This is slower than pre-computing the grid in shared memory and loading from there
+            //aux32[0] = ((IQ3S_MULTIPLIER * (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101;
+            //aux32[1] = ((IQ3S_MULTIPLIER * (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256))) & 0x0f0f0f0f) | 0x01010101;
+            //for (int j = 0; j < 4; ++j) {
+            //    sum[0] += yl[8*l + j + 0] * grid[j+0] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
+            //    sum[1] += yl[8*l + j + 4] * grid[j+4] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
+            //}
+            threadgroup const uint8_t * grid1 = (threadgroup const uint8_t *)(values + (qs[2*l+0] | ((qh[0] << (8-2*l)) & 256)));
+            threadgroup const uint8_t * grid2 = (threadgroup const uint8_t *)(values + (qs[2*l+1] | ((qh[0] << (7-2*l)) & 256)));
             for (int j = 0; j < 4; ++j) {
-                //sum[0] += yl[8*l + j + 0] * (2*((grid[j+0] - 1)/2) + 1) * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
-                //sum[1] += yl[8*l + j + 4] * (2*((grid[j+4] - 1)/2) + 1) * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
                 sum[0] += yl[8*l + j + 0] * grid1[j] * select(1, -1, signs[l] & kmask_iq2xs[j+0]);
                 sum[1] += yl[8*l + j + 4] * grid2[j] * select(1, -1, signs[l] & kmask_iq2xs[j+4]);
             }
@@ -5657,6 +5667,7 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
     const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
     uint32_t aux32[2];
     thread const int8_t * grid = (thread const int8_t *)aux32;
+#ifdef IQ3S_SLOW_MULT
     aux32[0] = (IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f;
     aux32[1] = (IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f;
     for (int i = 0; i < 4; ++i) {
@@ -5669,6 +5680,20 @@ void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 &
         reg[2][i] = dl * (2*((grid[i+0]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
         reg[3][i] = dl * (2*((grid[i+4]-1)/2)+1) * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
     }
+#else
+    aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+0] | ((qh << 8) & 256))) & 0x0f0f0f0f) | 0x01010101;
+    aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+1] | ((qh << 7) & 256))) & 0x0f0f0f0f) | 0x01010101;
+    for (int i = 0; i < 4; ++i) {
+        reg[0][i] = dl * grid[i+0] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
+        reg[1][i] = dl * grid[i+4] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
+    }
+    aux32[0] = ((IQ3S_MULTIPLIER * (qs[4*il+2] | ((qh << 6) & 256))) & 0x0f0f0f0f) | 0x01010101;
+    aux32[1] = ((IQ3S_MULTIPLIER * (qs[4*il+3] | ((qh << 5) & 256))) & 0x0f0f0f0f) | 0x01010101;
+    for (int i = 0; i < 4; ++i) {
+        reg[2][i] = dl * grid[i+0] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
+        reg[3][i] = dl * grid[i+4] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
+    }
+#endif
 }
 
 template
diff --git a/ggml-quants.c b/ggml-quants.c
index 7ce2f077a..e4478102f 100644
--- a/ggml-quants.c
+++ b/ggml-quants.c
@@ -10070,6 +10070,7 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
                                                vmovl_u8(vget_low_u8(idx_l)));
             const uint16x8_t idx_2 = vorrq_u16(vandq_u16(vshlq_u16(vdupq_n_u16(qh[ib32+1]), idx_shift), idx_mask1),
                                                vmovl_u8(vget_high_u8(idx_l)));
+#ifdef IQ3S_SLOW_MULT
             q3s.val[0] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2));
             q3s.val[1] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2));
             q3s.val[2] = vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2));
@@ -10078,6 +10079,12 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
             q3s.val[1] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[1], m1), m0), 1), 1), m1);
             q3s.val[2] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[2], m1), m0), 1), 1), m1);
             q3s.val[3] = vorrq_s8(vshlq_n_s8(vshrq_n_u8(vmaxq_s8(vsubq_s8(q3s.val[3], m1), m0), 1), 1), m1);
+#else
+            q3s.val[0] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_1))), idx_mask2)), m1);
+            q3s.val[1] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_1))), idx_mask2)), m1);
+            q3s.val[2] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_low_u16 (idx_2))), idx_mask2)), m1);
+            q3s.val[3] = vorrq_s8(vreinterpretq_s8_u32(vandq_u32(vmulq_u32(idx_mult, vmovl_u16(vget_high_u16(idx_2))), idx_mask2)), m1);
+#endif
 
             vs.val[0] = vreinterpretq_u8_u32(vdupq_n_u32(signs[0] | (signs[1] << 16)));
             vs.val[1] = vandq_u8(ggml_vqtbl1q_u8(vs.val[0], mask1.val[1]), mask2);
@@ -10094,8 +10101,6 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
             vs.val[0] = vorrq_u8(vceqq_u8(vs.val[0], mask2), vreinterpretq_u8_s8(m1));
             vs.val[1] = vorrq_u8(vceqq_u8(vs.val[1], mask2), vreinterpretq_u8_s8(m1));
 
-            signs += 4;
-
             q3s.val[2] = vmulq_s8(vreinterpretq_s8_u8(vs.val[0]), q3s.val[2]);
             q3s.val[3] = vmulq_s8(vreinterpretq_s8_u8(vs.val[1]), q3s.val[3]);
 
@@ -10103,6 +10108,8 @@ void ggml_vec_dot_iq3_s_q8_K (int n, float * GGML_RESTRICT s, size_t bs, const v
             const int32x4_t p2 = ggml_vdotq_s32(ggml_vdotq_s32(vdupq_n_s32(0), q3s.val[2], q8b.val[2]), q3s.val[3], q8b.val[3]);
             sumi1 += vaddvq_s32(p1) * (1 + 2*(x[i].scales[ib32/2] & 0xf));
             sumi2 += vaddvq_s32(p2) * (1 + 2*(x[i].scales[ib32/2] >> 4));
+
+            signs += 4;
         }
         sumf += d*(sumi1 + sumi2);
     }
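A quick way to sanity-check the new multiplier outside the patch: the two odd-forcing fixups are not byte-wise equivalent (2*((q-1)/2) + 1 rounds an even byte down to the previous odd value, while | 0x01010101 rounds it up to the next one), so the fast path only works because the multiplier changes from 190842953 to 842866 in tandem. The minimal host-side C harness below (expand_slow/expand_fast are illustrative names, not patch code) checks whether both schemes expand all 512 9-bit IQ3_S grid indices to the same values.

// Standalone sketch, not part of the patch: compare the two index-expansion
// schemes used in the diff over the full 9-bit grid index range.
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// old ("slow") scheme: multiplier 190842953, then round each nibble-masked
// byte down to an odd value with 2*((q-1)/2) + 1
static uint32_t expand_slow(uint32_t idx) {
    uint32_t aux32 = (190842953u * idx) & 0x0f0f0f0fu;
    int8_t q[4];
    memcpy(q, &aux32, 4);                       // avoids strict-aliasing issues
    for (int k = 0; k < 4; ++k) q[k] = (int8_t)(2*((q[k] - 1)/2) + 1);
    memcpy(&aux32, q, 4);
    return aux32;
}

// new ("fast") scheme: multiplier 842866, odd-ness forced by a single
// bitwise OR that sets the low bit of every byte
static uint32_t expand_fast(uint32_t idx) {
    return ((842866u * idx) & 0x0f0f0f0fu) | 0x01010101u;
}

int main(void) {
    int ndiff = 0;
    for (uint32_t idx = 0; idx < 512; ++idx) {  // 8-bit qs byte + 1 high bit from qh
        if (expand_slow(idx) != expand_fast(idx)) ++ndiff;
    }
    printf("%d of 512 grid entries differ\n", ndiff);
    return 0;
}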