diff --git a/ggml-metal.metal b/ggml-metal.metal index 50185ae4d..76999327f 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -2595,8 +2595,8 @@ typedef struct { typedef struct { half d; - uint8_t qs[QK_K/8]; - uint8_t scales[QK_K/16]; + uint8_t qs[QK_K/8]; + uint16_t qh[QK_K/32]; } block_iq1_s; // Non-linear quants @@ -4358,25 +4358,24 @@ void kernel_mul_mv_iq1_s_f32_impl( const int ib = ib32 % (QK_K / 32); device const block_iq1_s * xr = x + ibl; - device const uint8_t * qs = xr->qs + 4 * ib + 2 * il; - device const uint8_t * sc = xr->scales + 2 * ib + il; - device const half * dh = &xr->d; + device const uint8_t * qs = xr->qs + 4 * ib + 2 * il; + device const uint16_t * qh = xr->qh + ib; + device const half * dh = &xr->d; for (int row = 0; row < N_DST; row++) { - constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); - constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); + constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | (((qh[0] >> (6*il+0)) & 7) << 8))); + constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | (((qh[0] >> (6*il+3)) & 7) << 8))); - float2 sum = {0}; + float sum = 0; for (int j = 0; j < 8; ++j) { - sum[0] += yl[j+ 0] * grid1[j]; - sum[1] += yl[j+ 8] * grid2[j]; + sum += yl[j+0] * grid1[j] + yl[j+8] * grid2[j]; } - sumf[row] += (float)dh[0] * (sum[0] * (2*(sc[0] & 7) + 1) + sum[1] * (2*((sc[0] >> 4) & 7) + 1)); + sumf[row] += (float)dh[0] * sum * (2*(qh[0] >> 12) + 1); dh += nb*sizeof(block_iq1_s)/2; qs += nb*sizeof(block_iq1_s); - sc += nb*sizeof(block_iq1_s); + qh += nb*sizeof(block_iq1_s)/2; } y4 += 16 * 32; @@ -5066,16 +5065,17 @@ void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & template void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) { // il is 0...15 for QK_K = 256 => index of block of 32 is il/2 + const int ib32 = il/2; + il = il%2; const float d = xb->d; - device const uint8_t * qs = xb->qs + 2*il; - device const uint8_t * sc = xb->scales + il; - const float dl1 = d * (2*(sc[0] & 7) + 1); - const float dl2 = d * (2*((sc[0] >> 4) & 7) + 1); - constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((sc[0] & 0x08) << 5))); - constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((sc[0] & 0x80) << 1))); + device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; + device const uint16_t * qh = xb->qh; + const float dl = d * (2*(qh[ib32] >> 12) + 1); + constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | (((qh[ib32] >> (6*il+0)) & 7) << 8))); + constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | (((qh[ib32] >> (6*il+3)) & 7) << 8))); for (int i = 0; i < 8; ++i) { - reg[i/4+0][i%4] = dl1 * grid1[i]; - reg[i/4+2][i%4] = dl2 * grid2[i]; + reg[i/4+0][i%4] = dl * grid1[i]; + reg[i/4+2][i%4] = dl * grid2[i]; } } diff --git a/ggml-quants.c b/ggml-quants.c index 90579ab7c..b1bd68ea3 100644 --- a/ggml-quants.c +++ b/ggml-quants.c @@ -9570,60 +9570,43 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void const int nb = n / QK_K; - // TODO: implement for QK_K = 64 -#if defined __ARM_NEON && QK_K == 256 +#if defined __ARM_NEON - const uint8x16_t m8 = vdupq_n_u8(0x08); - const uint8x16_t m7 = vdupq_n_u8(0x07); - const uint8x16_t m1 = vdupq_n_u8(0x01); - const int32x4_t vzero = vdupq_n_s32(0); - - uint16_t gindex[8]; - uint16x8x2_t vindex; - int8x16x4_t q1b; + ggml_int8x16x4_t q1b; ggml_int8x16x4_t q8b; - uint16x8x4_t scales; - int32x4x2_t sumi; - int32x4x2_t dotq; float sumf = 0; for (int i = 0; i < nb; ++i) { - const int8_t * q8 = y[i].qs; - const uint8_t * qs = x[i].qs; - const uint8_t * sc = x[i].scales; + const int8_t * q8 = y[i].qs; + const uint8_t * qs = x[i].qs; + const uint16_t * qh = x[i].qh; - sumi.val[0] = sumi.val[1] = vzero; + int sumi1 = 0, sumi2 = 0; - for (int i128 = 0; i128 < QK_K/128; ++i128) { - const uint8x16_t ql = vld1q_u8(qs); qs += 16; - const uint8x8_t tm1 = vld1_u8 (sc); sc += 8; - const uint8x8_t tm2 = vshr_n_u8(tm1, 4); - const uint8x16_t qh = vcombine_u8(vzip1_u8(tm1, tm2), vzip2_u8(tm1, tm2)); - const uint8x16_t hbit = vandq_u8(qh, m8); - vindex.val[0] = vorrq_u16(vmovl_u8(vget_low_u8 (ql)), vshlq_n_u16(vmovl_u8(vget_low_u8 (hbit)), 5)); - vindex.val[1] = vorrq_u16(vmovl_u8(vget_high_u8(ql)), vshlq_n_u16(vmovl_u8(vget_high_u8(hbit)), 5)); - const uint8x16_t scales8 = vorrq_u8(vshlq_n_u8(vandq_u8(qh, m7), 1), m1); - scales.val[0] = vmovl_u8(vget_low_u8 (scales8)); - scales.val[1] = vmovl_u8(vget_high_u8 (scales8)); + for (int ib = 0; ib < QK_K/32; ib += 2) { - for (int l = 0; l < 2; ++l) { - vst1q_u16(gindex+0, vindex.val[l]); - q1b.val[0] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[0])), vld1_s8((const void *)(iq1s_grid+gindex[1]))); - q1b.val[1] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[2])), vld1_s8((const void *)(iq1s_grid+gindex[3]))); - q1b.val[2] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[4])), vld1_s8((const void *)(iq1s_grid+gindex[5]))); - q1b.val[3] = vcombine_s8(vld1_s8((const void *)(iq1s_grid+gindex[6])), vld1_s8((const void *)(iq1s_grid+gindex[7]))); - q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + q1b.val[0] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[0] | (((qh[ib+0] >> 0) & 7) << 8)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[1] | (((qh[ib+0] >> 3) & 7) << 8))))); + q1b.val[1] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[2] | (((qh[ib+0] >> 6) & 7) << 8)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[3] | (((qh[ib+0] >> 9) & 7) << 8))))); + q1b.val[2] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[4] | (((qh[ib+1] >> 0) & 7) << 8)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[5] | (((qh[ib+1] >> 3) & 7) << 8))))); + q1b.val[3] = vcombine_s8(vld1_s8((const int8_t *)(iq1s_grid + (qs[6] | (((qh[ib+1] >> 6) & 7) << 8)))), + vld1_s8((const int8_t *)(iq1s_grid + (qs[7] | (((qh[ib+1] >> 9) & 7) << 8))))); + qs += 8; - dotq.val[0] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[0], q8b.val[0]), ggml_vdotq_s32(vzero, q1b.val[1], q8b.val[1])); - dotq.val[1] = vpaddq_s32(ggml_vdotq_s32(vzero, q1b.val[2], q8b.val[2]), ggml_vdotq_s32(vzero, q1b.val[3], q8b.val[3])); + q8b = ggml_vld1q_s8_x4(q8); q8 += 64; + + const int32x4_t p1 = ggml_vdotq_s32(q1b.val[0], q8b.val[0], ggml_vdotq_s32(q1b.val[1], q8b.val[1], vdupq_n_s32(0))); + const int32x4_t p2 = ggml_vdotq_s32(q1b.val[2], q8b.val[2], ggml_vdotq_s32(q1b.val[3], q8b.val[3], vdupq_n_s32(0))); + + sumi1 += vaddvq_s32(p1) * (2*(qh[ib+0] >> 12) + 1); + sumi2 += vaddvq_s32(p2) * (2*(qh[ib+1] >> 12) + 1); - sumi.val[0] = vmlaq_s32(sumi.val[0], dotq.val[0], vreinterpretq_s32_u32(vmovl_u16(vget_low_u16 (scales.val[l])))); - sumi.val[1] = vmlaq_s32(sumi.val[1], dotq.val[1], vreinterpretq_s32_u32(vmovl_u16(vget_high_u16(scales.val[l])))); - } } - sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * vaddvq_s32(vaddq_s32(sumi.val[0], sumi.val[1])); + sumf += y[i].d * GGML_FP16_TO_FP32(x[i].d) * (sumi1 + sumi2); } *s = sumf; @@ -9640,9 +9623,9 @@ void ggml_vec_dot_iq1_s_q8_K (int n, float * restrict s, size_t bs, const void __m256i sumi = _mm256_setzero_si256(); for (int ib = 0; ib < QK_K/32; ib += 2) { const __m256i q1b_1 = _mm256_set_epi64x(iq1s_grid[qs[3] | (((qh[ib+0] >> 9) & 7) << 8)], iq1s_grid[qs[2] | (((qh[ib+0] >> 6) & 7) << 8)], - iq1s_grid[qs[1] | (((qh[ib+0] >> 3) & 7) << 8)], iq1s_grid[qs[0] | (((qh[ib+0] >> 0) & 7) << 8)]); + iq1s_grid[qs[1] | (((qh[ib+0] >> 3) & 7) << 8)], iq1s_grid[qs[0] | (((qh[ib+0] >> 0) & 7) << 8)]); const __m256i q1b_2 = _mm256_set_epi64x(iq1s_grid[qs[7] | (((qh[ib+1] >> 9) & 7) << 8)], iq1s_grid[qs[6] | (((qh[ib+1] >> 6) & 7) << 8)], - iq1s_grid[qs[5] | (((qh[ib+1] >> 3) & 7) << 8)], iq1s_grid[qs[4] | (((qh[ib+1] >> 0) & 7) << 8)]); + iq1s_grid[qs[5] | (((qh[ib+1] >> 3) & 7) << 8)], iq1s_grid[qs[4] | (((qh[ib+1] >> 0) & 7) << 8)]); qs += 8; const __m256i q8b_1 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32; const __m256i q8b_2 = _mm256_loadu_si256((const __m256i*)q8); q8 += 32;