diff --git a/ggml-metal.m b/ggml-metal.m
index 15d265fa0..2e5e66381 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -743,7 +743,7 @@ void ggml_metal_graph_compute(
                                 src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
-                            else if (src0t == GGML_TYPE_Q5_K) {
+                            else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q5_K) {
                                 [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                             }
                             else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q6_K) {
diff --git a/ggml-metal.metal b/ggml-metal.metal
index a11fcfad2..11492f814 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -1356,11 +1356,13 @@ kernel void kernel_mul_mat_q3_K_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    const int row = 2 * r0 + sgitg;
+    const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
 
-    device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
+    device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
     device const float * yy = (device const float *) src1 + r1*ne10;
 
+    float yl[16];
+
 #if QK_K == 256
 
     const uint16_t kmask1 = 0x0303;
@@ -1390,43 +1392,58 @@ kernel void kernel_mul_mat_q3_K_f32(
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
 
-    float sumf1 = 0, sumf2 = 0;
+    const int step = sizeof(block_q3_K) * nb;
+
+    device const float * y1 = yy + ix*QK_K + y_offset;
+
+    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
 
     for (int i = ix; i < nb; i += 2) {
 
-        const float d_all = (float)(x[i].d);
+        for (int l = 0; l < 8; ++l) {
+            yl[l+0] = y1[l+ 0];
+            yl[l+8] = y1[l+16];
+        }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
         device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
-        device const float * y = yy + i * QK_K + y_offset;
+        device const uint16_t * a = (device const uint16_t *)(x[i].scales);
+        device const half * dh = &x[i].d;
 
-        device const uint16_t * a = (device const uint16_t *)x[i].scales;
-        const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+        for (int row = 0; row < 2; ++row) {
+
+            const float d_all = (float)dh[0];
+            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
+
+            float s1 = 0, s2 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const uint16_t qs = q[l/2];
+                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
+                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+            }
+            float d = d_all * (s1 + 1.f/256.f * s2);
+            sumf1[row] += d * scales[0];
+            sumf2[row] += d;
+
+            s1 = s2 = 0;
+            for (int l = 0; l < n; l += 2) {
+                const uint16_t qs = q[l/2+8];
+                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
+                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+            }
+            d = d_all * (s1 + 1.f/256.f * s2);
+            sumf1[row] += d * scales[1];
+            sumf2[row] += d;
+
+            q += step/2;
+            h += step/2;
+            a += step/2;
+            dh += step/2;
 
-        float s1 = 0, s2 = 0;
-        for (int l = 0; l < n; l += 2) {
-            const uint16_t qs = q[l/2];
-            s1 += y[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-            s2 += y[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
         }
-        float d = d_all * (s1 + 1.f/256.f * s2);
-        sumf1 += d * scales[0];
-        sumf2 += d;
-        y += 16;
-        q += 8;
-        h += 8;
-        s1 = s2 = 0;
-        for (int l = 0; l < n; l += 2) {
-            const uint16_t qs = q[l/2];
-            s1 += y[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-            s2 += y[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
-        }
-        d = d_all * (s1 + 1.f/256.f * s2);
-        sumf1 += d * scales[1];
-        sumf2 += d;
+        y1 += 2 * QK_K;
     }
 
-    const float sumf = (sumf1 - 32.f*sumf2) / (1 << shift);
 #else
     const int ix = tiisg/4;
     const int il = 4 * (tiisg%4);// 0, 4, 8, 12
@@ -1466,9 +1483,12 @@ kernel void kernel_mul_mat_q3_K_f32(
                      (sum2[0] + sum2[1] * 1.f/4.f + sum2[2] * 1.f/16.f + sum2[3] * 1.f/64.f) * 1.f/256.f;
 
 #endif
-    const float tot = simd_sum(sumf);
-    if (tiisg == 0) {
-        dst[r1*ne0 + row] = tot;
+    for (int row = 0; row < 2; ++row) {
+        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
+        const float tot = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[r1*ne0 + first_row + row] = tot;
+        }
     }
 }