diff --git a/ggml-metal.metal b/ggml-metal.metal
index 119fcbeb6..9901fc983 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1140,7 +1140,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     float yl[16];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
     const int tid = tiisg/2;
@@ -1155,10 +1155,9 @@ kernel void kernel_mul_mat_q3_K_f32(
     const uint16_t m2 = m1 << 8;
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const int32_t qm1 = 0x0003 << shift;
+    const int32_t qm2 = 0x0300 << shift;
+    const float v1 = 4 << shift;
 
     const uint16_t s_shift1 = 4*ip;
     const uint16_t s_shift2 = s_shift1 + 2*(il/2);
@@ -1171,7 +1170,10 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
+    uint16_t scales16;
+    thread const int8_t * scales = (thread const int8_t *)&scales16;
+
+    float sumf1[2] = {0.f};
     for (int i = ix; i < nb; i += 2) {
 
         for (int l = 0; l < 8; ++l) {
@@ -1187,27 +1189,27 @@ kernel void kernel_mul_mat_q3_K_f32(
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+            scales16 = ((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) << 4) & kmask1);
 
-            float s1 = 0, s2 = 0;
+            float s1 = 0, s2 = 0, s3 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm1);
+                s2 += yl[l+1] * (qs & qm2);
+                s3 += ((h[l/2] & m1) ? 0.f : yl[l+0]) + ((h[l/2] & m2) ? 0.f : yl[l+1]);
             }
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
+            float d = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            sumf1[row] += d * (scales[0] - 32);
 
-            s1 = s2 = 0;
+            s1 = s2 = s3 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm1);
+                s2 += yl[l+9] * (qs & qm2);
+                s3 += ((h[l/2+8] & m1) ? 0.f : yl[l+8]) + ((h[l/2+8] & m2) ? 0.f : yl[l+9]);
             }
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            sumf1[row] += d * (scales[1] - 32);
 
             q += step;
             h += step;
@@ -1221,7 +1223,7 @@ kernel void kernel_mul_mat_q3_K_f32(
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
+        const float sumf = sumf1[row] / (1 << shift);
         const float tot = simd_sum(sumf);
         if (tiisg == 0) {
             dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
@@ -1579,17 +1581,25 @@ kernel void kernel_mul_mat_q5_K_f32(
             sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
             sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-            float4 acc = {0.f, 0.f, 0.f, 0.f};
+            float4 acc1 = {0.f};
+            float4 acc2 = {0.f};
             for (int l = 0; l < n; ++l) {
                 uint8_t h = qh[l];
-                acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-                acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-                acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-                acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+                acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+                acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+                acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+                acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+                acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+                acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+                acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+                acc2[3] += h & hm4 ? yh[l+8] : 0.f;
             }
             const float dall = dh[0];
             const float dmin = dh[1];
-            sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+            sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                                 sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                                 sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                                 sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                          dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
             q1 += step;
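
Note (not part of the patch): the q3_K hunks rest on the identity
y0*((qs & qm1) - (h & m1 ? 0 : v1)) + y1*((qs & qm2) - (h & m2 ? 0 : v2))/256
  == y0*(qs & qm1) + y1*(qs & qm2)/256 - v1*((h & m1 ? 0 : y0) + (h & m2 ? 0 : y1)),
which holds because v2 == 256*v1. That is what lets the per-element offset be hoisted out of the inner loop as a single s3*v1 term, and folding the -32 scale bias into (scales[x] - 32) removes the separate sumf2 accumulator in the same way. The plain-C sketch below checks the identity over a grid of packed quant/mask words; the activation values y0/y1, the bit position chosen for m1, and the tolerance are illustrative assumptions, not taken from ggml.

// Illustrative check of the algebraic identity behind the q3_K rewrite.
// Values and names here (y0, y1, the m1 bit choice) are assumptions for the sketch.
#include <assert.h>
#include <math.h>
#include <stdint.h>
#include <stdio.h>

int main(void) {
    for (int il = 0; il < 4; ++il) {
        const int shift = 2*il;
        const uint16_t qm1 = 0x0003 << shift;
        const uint16_t qm2 = 0x0300 << shift;
        const int32_t  v1  = 4    << shift;
        const int32_t  v2  = 1024 << shift;   // note: v2 == 256*v1
        const uint16_t m1  = 1 << il;         // any low-byte bit works for the identity
        const uint16_t m2  = m1 << 8;

        for (uint32_t qs = 0; qs < 0x10000; qs += 257) {
            for (uint32_t h = 0; h < 0x10000; h += 511) {
                const float y0 = 0.25f, y1 = -1.5f;   // arbitrary sample activations

                // Old form: subtract the offset inside the loop, element by element.
                const float s1_old = y0 * ((int32_t)(qs & qm1) - ((h & m1) ? 0 : v1));
                const float s2_old = y1 * ((int32_t)(qs & qm2) - ((h & m2) ? 0 : v2));
                const float d_old  = s1_old + 1.f/256.f * s2_old;

                // New form: keep the masked products, collect the offset
                // contributions in s3, and fold them in once as s3*v1.
                const float s1_new = y0 * (qs & qm1);
                const float s2_new = y1 * (qs & qm2);
                const float s3     = ((h & m1) ? 0.f : y0) + ((h & m2) ? 0.f : y1);
                const float d_new  = s1_new + 1.f/256.f * s2_new - s3 * v1;

                assert(fabsf(d_old - d_new) < 1e-3f);
            }
        }
    }
    printf("old and new q3_K inner-loop forms agree\n");
    return 0;
}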