diff --git a/ggml-metal.metal b/ggml-metal.metal index e985f8a9e..f01be34db 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1087,23 +1087,27 @@ kernel void kernel_mul_mat_q3_k_f32( const int n = 8; const int l0 = n*ir; + const uint8_t m = 1 << (4*ip + il); + + const int shift = 2*il; + const int is = 8*ip + 2*il; + + //const int shift2 = 4*ip; + uint16_t aux[8]; thread const int8_t * scales = (thread const int8_t*)aux; + const int q_offset = 32*ip + l0; + const int y_offset = 128*ip + 32*il + l0; + float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { const float d_all = (float)(x[i].d); - device const uint8_t * q = x[i].qs + 32*ip + l0; + device const uint8_t * q = x[i].qs + q_offset; device const uint8_t * h = x[i].hmask + l0; - device const float * y = yy + i * QK_K + 128*ip + 32*il + l0; - - //device const uint32_t * a = (device const uint32_t *)x[i].scales; - //aux[0] = (a[0] & kmask2) | (((a[2] >> 0) & kmask1) << 4); - //aux[1] = (a[1] & kmask2) | (((a[2] >> 2) & kmask1) << 4); - //aux[2] = ((a[0] >> 4) & kmask2) | (((a[2] >> 4) & kmask1) << 4); - //aux[3] = ((a[1] >> 4) & kmask2) | (((a[2] >> 6) & kmask1) << 4); + device const float * y = yy + i * QK_K + y_offset; device const uint16_t * a = (device const uint16_t *)x[i].scales; aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4); @@ -1115,29 +1119,30 @@ kernel void kernel_mul_mat_q3_k_f32( aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4); aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4); - uint8_t m = 1 << (4*ip + il); - int is = 8*ip + 2*il; float dl; - //for (int n = 0; n < QK_K; n += 128) { - int shift = 2*il; - //for (int j = 0; j < 4; ++j) { - dl = d_all * (scales[is++] - 32); - for (int l = 0; l < n; ++l) { - sumf += y[l+ 0] * dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4)); - } - - dl = d_all * (scales[is++] - 32); - for (int l = 0; l < n; ++l) { - sumf += y[l+16] * dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4)); - } - - y += 32; - shift += 2; - m <<= 1; - //} - //q += 32; + //dl = d_all * (scales[is+0] - 32); + //for (int l = 0; l < n; ++l) { + // sumf += y[l+ 0] * dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4)); //} + + //dl = d_all * (scales[is+1] - 32); + //for (int l = 0; l < n; ++l) { + // sumf += y[l+16] * dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4)); + //} + + float s = 0; + for (int l = 0; l < n; ++l) { + s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4)); + } + sumf += s * d_all * (scales[is+0] - 32); + + s = 0; + for (int l = 0; l < n; ++l) { + s += y[l+16] * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4)); + } + sumf += s * d_all * (scales[is+1] - 32); + } sum[ith] = sumf;