Faster Q6_K on Metal

This commit is contained in:
Iwan Kawrakow 2023-07-20 16:00:22 +03:00
parent 785829dfe8
commit fa9d54e36e
2 changed files with 30 additions and 43 deletions

View file

@ -712,8 +712,8 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
nth0 = 2;
nth1 = 32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
} break;
default:
@ -743,11 +743,12 @@ void ggml_metal_graph_compute(
src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_Q3_K ||
src0t == GGML_TYPE_Q4_K ||
src0t == GGML_TYPE_Q5_K ||
src0t == GGML_TYPE_Q6_K) {
src0t == GGML_TYPE_Q5_K) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {

View file

@ -1766,10 +1766,9 @@ kernel void kernel_mul_mat_q6_K_f32(
constant int64_t & ne00,
constant int64_t & ne10,
constant int64_t & ne0,
threadgroup float * sum [[threadgroup(0)]],
uint2 tgpig[[threadgroup_position_in_grid]],
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
@ -1781,19 +1780,18 @@ kernel void kernel_mul_mat_q6_K_f32(
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
device const float * yy = (device const float *) src1 + r1*ne10;
const int row = 2 * r0 + sgitg;
const int nth = tptg.x*tptg.y;
const int ith = tptg.y*tpitg.x + tpitg.y;
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
device const float * yy = (device const float *) src1 + r1*ne10;
float sumf = 0;
#if QK_K == 256
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
const int iqs = 16 * tpitg.y;
const int ip = iqs / 128; // 0 or 1
const int il = (iqs - 128*ip)/16; // 0...7
const int tid = tiisg/2;
const int ix = tiisg%2;
const int ip = tid/8; // 0 or 1
const int il = tid%8;
const int n = 4;
const int l0 = n*il;
const int is = 8*ip + l0/16;
@ -1802,9 +1800,10 @@ kernel void kernel_mul_mat_q6_K_f32(
const int q_offset_l = 64*ip + l0;
const int q_offset_h = 32*ip + l0;
for (int i = tpitg.x; i < nb; i += tptg.x) {
for (int i = ix; i < nb; i += 2) {
device const uint8_t * ql = x[i].ql + q_offset_l;
device const uint8_t * q1 = x[i].ql + q_offset_l;
device const uint8_t * q2 = q1 + 32;
device const uint8_t * qh = x[i].qh + q_offset_h;
device const int8_t * sc = x[i].scales + is;
@ -1814,19 +1813,21 @@ kernel void kernel_mul_mat_q6_K_f32(
float4 sums = {0.f, 0.f, 0.f, 0.f};
for (int l = 0; l < n; ++l) {
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
}
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
}
#else
const int il = 4*tpitg.x; // 0, 4, 8, 12
for (int i = tpitg.y; i < nb; i += tptg.y) {
#else
const int ix = tiisg/4;
const int il = 4*(tiisg%4);
for (int i = ix; i < nb; i += 8) {
device const float * y = yy + i * QK_K + il;
device const uint8_t * ql = x[i].ql + il;
device const uint8_t * qh = x[i].qh + il;
@ -1846,23 +1847,8 @@ kernel void kernel_mul_mat_q6_K_f32(
#endif
sum[ith] = sumf;
//
// Accumulate the sum from all threads in the threadgroup
//
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%4 == 0) {
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
const float tot = simd_sum(sumf);
if (tiisg == 0) {
dst[r1*ne0 + row] = tot;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%16 == 0) {
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith == 0) {
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[r1*ne0 + r0] = sum[0];
}
}