Additional Q3_K speedup on Metal

This commit is contained in:
Iwan Kawrakow 2023-07-20 20:28:28 +03:00
parent 5bb23b5ab5
commit 8dba28c00a
2 changed files with 52 additions and 32 deletions

View file

@ -743,7 +743,7 @@ void ggml_metal_graph_compute(
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src0t == GGML_TYPE_Q5_K) { else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q5_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} }
else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q6_K) { else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q6_K) {

View file

@ -1356,11 +1356,13 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r0 = tgpig.x; const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y; const int64_t r1 = tgpig.y;
const int row = 2 * r0 + sgitg; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb; device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
device const float * yy = (device const float *) src1 + r1*ne10; device const float * yy = (device const float *) src1 + r1*ne10;
float yl[16];
#if QK_K == 256 #if QK_K == 256
const uint16_t kmask1 = 0x0303; const uint16_t kmask1 = 0x0303;
@ -1390,43 +1392,58 @@ kernel void kernel_mul_mat_q3_K_f32(
const int q_offset = 32*ip + l0; const int q_offset = 32*ip + l0;
const int y_offset = 128*ip + 32*il + l0; const int y_offset = 128*ip + 32*il + l0;
float sumf1 = 0, sumf2 = 0; const int step = sizeof(block_q3_K) * nb;
device const float * y1 = yy + ix*QK_K + y_offset;
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
for (int i = ix; i < nb; i += 2) { for (int i = ix; i < nb; i += 2) {
const float d_all = (float)(x[i].d); for (int l = 0; l < 8; ++l) {
yl[l+0] = y1[l+ 0];
yl[l+8] = y1[l+16];
}
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
device const float * y = yy + i * QK_K + y_offset; device const uint16_t * a = (device const uint16_t *)(x[i].scales);
device const half * dh = &x[i].d;
device const uint16_t * a = (device const uint16_t *)x[i].scales; for (int row = 0; row < 2; ++row) {
const float d_all = (float)dh[0];
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
float s1 = 0, s2 = 0; float s1 = 0, s2 = 0;
for (int l = 0; l < n; l += 2) { for (int l = 0; l < n; l += 2) {
const uint16_t qs = q[l/2]; const uint16_t qs = q[l/2];
s1 += y[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
s2 += y[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
} }
float d = d_all * (s1 + 1.f/256.f * s2); float d = d_all * (s1 + 1.f/256.f * s2);
sumf1 += d * scales[0]; sumf1[row] += d * scales[0];
sumf2 += d; sumf2[row] += d;
y += 16;
q += 8;
h += 8;
s1 = s2 = 0; s1 = s2 = 0;
for (int l = 0; l < n; l += 2) { for (int l = 0; l < n; l += 2) {
const uint16_t qs = q[l/2]; const uint16_t qs = q[l/2+8];
s1 += y[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
s2 += y[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
} }
d = d_all * (s1 + 1.f/256.f * s2); d = d_all * (s1 + 1.f/256.f * s2);
sumf1 += d * scales[1]; sumf1[row] += d * scales[1];
sumf2 += d; sumf2[row] += d;
q += step/2;
h += step/2;
a += step/2;
dh += step/2;
}
y1 += 2 * QK_K;
} }
const float sumf = (sumf1 - 32.f*sumf2) / (1 << shift);
#else #else
const int ix = tiisg/4; const int ix = tiisg/4;
const int il = 4 * (tiisg%4);// 0, 4, 8, 12 const int il = 4 * (tiisg%4);// 0, 4, 8, 12
@ -1466,9 +1483,12 @@ kernel void kernel_mul_mat_q3_K_f32(
(sum2[0] + sum2[1] * 1.f/4.f + sum2[2] * 1.f/16.f + sum2[3] * 1.f/64.f) * 1.f/256.f; (sum2[0] + sum2[1] * 1.f/4.f + sum2[2] * 1.f/16.f + sum2[3] * 1.f/64.f) * 1.f/256.f;
#endif #endif
for (int row = 0; row < 2; ++row) {
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
const float tot = simd_sum(sumf); const float tot = simd_sum(sumf);
if (tiisg == 0) { if (tiisg == 0) {
dst[r1*ne0 + row] = tot; dst[r1*ne0 + first_row + row] = tot;
}
} }
} }