diff --git a/ggml-metal.metal b/ggml-metal.metal index d0cf17cad..e851cbd4d 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -782,37 +782,29 @@ kernel void kernel_mul_mat_q6_k_f32( const int iqs = step * tpitg.y; // 0...240 in steps of 16 const int ip = iqs / 128; // 0 or 1 const int il = (iqs - 128*ip)/16; // 0...7 - const int is = 8*ip; + const int n = 4; + const int is = 8*ip + (n*il)/16; float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { - device const uint8_t * ql = x[i].ql + 64*ip + il; - device const uint8_t * qh = x[i].qh + 32*ip + il; + device const uint8_t * ql = x[i].ql + 64*ip + n*il; + device const uint8_t * qh = x[i].qh + 32*ip + n*il; device const int8_t * sc = x[i].scales + is; - device const float * y = yy + i * QK_K + 128*ip + il; + device const float * y = yy + i * QK_K + 128*ip + n*il; const float dall = x[i].d; - float result = sc[0] * y[ 0] * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & kmask1) << 4)) - 32) - + sc[0] * y[ 8] * ((int8_t)((ql[ 8] & 0xF) | ((qh[ 8] & kmask1) << 4)) - 32) - + sc[1] * y[ 16] * ((int8_t)((ql[16] & 0xF) | ((qh[16] & kmask1) << 4)) - 32) - + sc[1] * y[ 24] * ((int8_t)((ql[24] & 0xF) | ((qh[24] & kmask1) << 4)) - 32) - + sc[2] * y[ 32] * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & kmask2) << 2)) - 32) - + sc[2] * y[ 40] * ((int8_t)((ql[40] & 0xF) | ((qh[ 8] & kmask2) << 2)) - 32) - + sc[3] * y[ 48] * ((int8_t)((ql[48] & 0xF) | ((qh[16] & kmask2) << 2)) - 32) - + sc[3] * y[ 56] * ((int8_t)((ql[56] & 0xF) | ((qh[24] & kmask2) << 2)) - 32) - + sc[4] * y[ 64] * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & kmask3) << 0)) - 32) - + sc[4] * y[ 72] * ((int8_t)((ql[ 8] >> 4) | ((qh[ 8] & kmask3) << 0)) - 32) - + sc[5] * y[ 80] * ((int8_t)((ql[16] >> 4) | ((qh[16] & kmask3) << 0)) - 32) - + sc[5] * y[ 88] * ((int8_t)((ql[24] >> 4) | ((qh[24] & kmask3) << 0)) - 32) - + sc[6] * y[ 96] * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & kmask4) >> 2)) - 32) - + sc[6] * y[104] * ((int8_t)((ql[40] >> 4) | ((qh[ 8] & kmask4) >> 2)) - 32) - + sc[7] * y[112] * ((int8_t)((ql[48] >> 4) | ((qh[16] & kmask4) >> 2)) - 32) - + sc[7] * y[120] * ((int8_t)((ql[56] >> 4) | ((qh[24] & kmask4) >> 2)) - 32); + float4 sums = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32); + sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32); + sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32); + sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32); + } - sumf += dall * result; + sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]); }