diff --git a/ggml-metal.metal b/ggml-metal.metal index a6c48c619..ff0198b62 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -4349,8 +4349,10 @@ void kernel_mul_mv_iq1_s_f32_impl( for (int ib32 = ix; ib32 < nb32; ib32 += 32) { + float sumy = 0; for (int i = 0; i < 32; ++i) { yl[i] = y4[i]; + sumy += yl[i]; } const int ibl = ib32 / (QK_K / 32); @@ -4363,16 +4365,19 @@ void kernel_mul_mv_iq1_s_f32_impl( for (int row = 0; row < N_DST; row++) { - constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700))); - constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 5) & 0x700))); - constant int8_t * grid3 = (constant int8_t *)(iq1s_grid + (qs[2] | ((qh[0] << 2) & 0x700))); - constant int8_t * grid4 = (constant int8_t *)(iq1s_grid + (qs[3] | ((qh[0] >> 1) & 0x700))); + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid + (qs[0] | ((qh[0] << 8) & 0x700))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid + (qs[1] | ((qh[0] << 5) & 0x700))); + constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid + (qs[2] | ((qh[0] << 2) & 0x700))); + constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid + (qs[3] | ((qh[0] >> 1) & 0x700))); float sum = 0; - for (int j = 0; j < 8; ++j) { - sum += yl[j+0] * grid1[j] + yl[j+8] * grid2[j] + yl[j+16] * grid3[j] + yl[j+24] * grid4[j]; + for (int j = 0; j < 4; ++j) { + sum += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4) + + yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4) + + yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4) + + yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4); } - sumf[row] += (float)dh[0] * sum * (2*(qh[0] >> 12) + 1); + sumf[row] += (float)dh[0] * (sum - sumy) * (2*(qh[0] >> 12) + 1); dh += nb*sizeof(block_iq1_s)/2; qs += nb*sizeof(block_iq1_s); @@ -5072,11 +5077,13 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & device const uint8_t * qs = xb->qs + 4*ib32 + 2*il; device const uint16_t * qh = xb->qh; const float dl = d * (2*(qh[ib32] >> 12) + 1); - constant int8_t * grid1 = (constant int8_t *)(iq1s_grid + (qs[0] | (((qh[ib32] >> (6*il+0)) & 7) << 8))); - constant int8_t * grid2 = (constant int8_t *)(iq1s_grid + (qs[1] | (((qh[ib32] >> (6*il+3)) & 7) << 8))); - for (int i = 0; i < 8; ++i) { - reg[i/4+0][i%4] = dl * grid1[i]; - reg[i/4+2][i%4] = dl * grid2[i]; + constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid + (qs[0] | (((qh[ib32] >> (6*il+0)) & 7) << 8))); + constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid + (qs[1] | (((qh[ib32] >> (6*il+3)) & 7) << 8))); + for (int i = 0; i < 4; ++i) { + reg[0][i] = dl * (grid1[i] & 0xf) - dl; + reg[1][i] = dl * (grid1[i] >> 4) - dl; + reg[2][i] = dl * (grid2[i] & 0xf) - dl; + reg[3][i] = dl * (grid2[i] >> 4) - dl; } }