Additional Q3_K speedup on Metal
This commit is contained in:
parent
5bb23b5ab5
commit
8dba28c00a
2 changed files with 52 additions and 32 deletions
|
@ -743,7 +743,7 @@ void ggml_metal_graph_compute(
|
||||||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q5_K) {
|
else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q5_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q6_K) {
|
else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q6_K) {
|
||||||
|
|
|
@ -1356,11 +1356,13 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
|
|
||||||
const int row = 2 * r0 + sgitg;
|
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
|
||||||
|
|
||||||
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
|
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
|
||||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||||
|
|
||||||
|
float yl[16];
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
|
|
||||||
const uint16_t kmask1 = 0x0303;
|
const uint16_t kmask1 = 0x0303;
|
||||||
|
@ -1390,43 +1392,58 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
const int q_offset = 32*ip + l0;
|
const int q_offset = 32*ip + l0;
|
||||||
const int y_offset = 128*ip + 32*il + l0;
|
const int y_offset = 128*ip + 32*il + l0;
|
||||||
|
|
||||||
float sumf1 = 0, sumf2 = 0;
|
const int step = sizeof(block_q3_K) * nb;
|
||||||
|
|
||||||
|
device const float * y1 = yy + ix*QK_K + y_offset;
|
||||||
|
|
||||||
|
float sumf1[2] = {0.f}, sumf2[2] = {0.f};
|
||||||
for (int i = ix; i < nb; i += 2) {
|
for (int i = ix; i < nb; i += 2) {
|
||||||
|
|
||||||
const float d_all = (float)(x[i].d);
|
for (int l = 0; l < 8; ++l) {
|
||||||
|
yl[l+0] = y1[l+ 0];
|
||||||
|
yl[l+8] = y1[l+16];
|
||||||
|
}
|
||||||
|
|
||||||
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
|
||||||
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
|
device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0);
|
||||||
device const float * y = yy + i * QK_K + y_offset;
|
device const uint16_t * a = (device const uint16_t *)(x[i].scales);
|
||||||
|
device const half * dh = &x[i].d;
|
||||||
|
|
||||||
device const uint16_t * a = (device const uint16_t *)x[i].scales;
|
for (int row = 0; row < 2; ++row) {
|
||||||
|
|
||||||
|
const float d_all = (float)dh[0];
|
||||||
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
|
||||||
|
|
||||||
float s1 = 0, s2 = 0;
|
float s1 = 0, s2 = 0;
|
||||||
for (int l = 0; l < n; l += 2) {
|
for (int l = 0; l < n; l += 2) {
|
||||||
const uint16_t qs = q[l/2];
|
const uint16_t qs = q[l/2];
|
||||||
s1 += y[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
||||||
s2 += y[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
||||||
}
|
}
|
||||||
float d = d_all * (s1 + 1.f/256.f * s2);
|
float d = d_all * (s1 + 1.f/256.f * s2);
|
||||||
sumf1 += d * scales[0];
|
sumf1[row] += d * scales[0];
|
||||||
sumf2 += d;
|
sumf2[row] += d;
|
||||||
|
|
||||||
y += 16;
|
|
||||||
q += 8;
|
|
||||||
h += 8;
|
|
||||||
s1 = s2 = 0;
|
s1 = s2 = 0;
|
||||||
for (int l = 0; l < n; l += 2) {
|
for (int l = 0; l < n; l += 2) {
|
||||||
const uint16_t qs = q[l/2];
|
const uint16_t qs = q[l/2+8];
|
||||||
s1 += y[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
|
s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
|
||||||
s2 += y[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
|
s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
|
||||||
}
|
}
|
||||||
d = d_all * (s1 + 1.f/256.f * s2);
|
d = d_all * (s1 + 1.f/256.f * s2);
|
||||||
sumf1 += d * scales[1];
|
sumf1[row] += d * scales[1];
|
||||||
sumf2 += d;
|
sumf2[row] += d;
|
||||||
|
|
||||||
|
q += step/2;
|
||||||
|
h += step/2;
|
||||||
|
a += step/2;
|
||||||
|
dh += step/2;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
y1 += 2 * QK_K;
|
||||||
|
|
||||||
}
|
}
|
||||||
const float sumf = (sumf1 - 32.f*sumf2) / (1 << shift);
|
|
||||||
#else
|
#else
|
||||||
const int ix = tiisg/4;
|
const int ix = tiisg/4;
|
||||||
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
|
||||||
|
@ -1466,9 +1483,12 @@ kernel void kernel_mul_mat_q3_K_f32(
|
||||||
(sum2[0] + sum2[1] * 1.f/4.f + sum2[2] * 1.f/16.f + sum2[3] * 1.f/64.f) * 1.f/256.f;
|
(sum2[0] + sum2[1] * 1.f/4.f + sum2[2] * 1.f/16.f + sum2[3] * 1.f/64.f) * 1.f/256.f;
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
for (int row = 0; row < 2; ++row) {
|
||||||
|
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
|
||||||
const float tot = simd_sum(sumf);
|
const float tot = simd_sum(sumf);
|
||||||
if (tiisg == 0) {
|
if (tiisg == 0) {
|
||||||
dst[r1*ne0 + row] = tot;
|
dst[r1*ne0 + first_row + row] = tot;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue