Faster Q6_K on Metal
parent 785829dfe8
commit fa9d54e36e

2 changed files with 30 additions and 43 deletions:

  ggml-metal.m      (11 changed lines)
  ggml-metal.metal
ggml-metal.m

@@ -712,8 +712,8 @@ void ggml_metal_graph_compute(
                         GGML_ASSERT(ne02 == 1);
                         GGML_ASSERT(ne12 == 1);
 
-                        nth0 = 4;
-                        nth1 = 16;
+                        nth0 = 2;
+                        nth1 = 32;
                         [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
                     } break;
                 default:
@@ -743,11 +743,12 @@ void ggml_metal_graph_compute(
                          src0t == GGML_TYPE_Q4_K) {
                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 }
+                else if (src0t == GGML_TYPE_Q6_K) {
+                    [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                }
                 else if (src0t == GGML_TYPE_Q2_K ||
                          src0t == GGML_TYPE_Q3_K ||
-                         src0t == GGML_TYPE_Q4_K ||
-                         src0t == GGML_TYPE_Q5_K ||
-                         src0t == GGML_TYPE_Q6_K) {
+                         src0t == GGML_TYPE_Q5_K) {
                     [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                 } else {
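Note on the new launch geometry (a sketch, not part of the commit): with nth0 = 2 and nth1 = 32 each threadgroup has 64 threads, i.e. two simdgroups of 32, and each simdgroup reduces one output row, so the grid shrinks from ne01 to (ne01 + 1)/2 threadgroups along x. A minimal plain C++ restatement of that arithmetic, with ne01 = 4096 as an example value only:

    // Rough host-side arithmetic for the new Q6_K dispatch (plain C++, not the
    // Objective-C encoder calls above). ne01/ne11/nth0/nth1 names follow the diff.
    #include <cstdio>

    int main() {
        const int ne01 = 4096, ne11 = 1;      // example matrix extents
        const int nth0 = 2,    nth1 = 32;     // threadsPerThreadgroup = 2 x 32 = 64
        const int tg_x = (ne01 + 1) / 2;      // threadgroups along x, 2 rows each

        std::printf("dispatch %d x %d threadgroups, %d threads (%d simdgroups) each\n",
                    tg_x, ne11, nth0*nth1, (nth0*nth1)/32);
        return 0;
    }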
ggml-metal.metal

@@ -1766,10 +1766,9 @@ kernel void kernel_mul_mat_q6_K_f32(
                        constant   int64_t & ne00,
                        constant   int64_t & ne10,
                        constant   int64_t & ne0,
-                       threadgroup float  * sum [[threadgroup(0)]],
                        uint2 tgpig[[threadgroup_position_in_grid]],
-                       uint2 tpitg[[thread_position_in_threadgroup]],
-                       uint2 tptg[[threads_per_threadgroup]]) {
+                       uint  tiisg[[thread_index_in_simdgroup]],
+                       uint  sgitg[[simdgroup_index_in_threadgroup]]) {
 
     const uint8_t kmask1 = 0x03;
     const uint8_t kmask2 = 0x0C;
@@ -1781,19 +1780,18 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
 
-    device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
-    device const float * yy = (device const float *) src1 + r1*ne10;
+    const int row = 2 * r0 + sgitg;
 
-    const int nth = tptg.x*tptg.y;
-    const int ith = tptg.y*tpitg.x + tpitg.y;
+    device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
+    device const float * yy = (device const float *) src1 + r1*ne10;
 
     float sumf = 0;
 
 #if QK_K == 256
-    // Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
-    const int iqs = 16 * tpitg.y;
-    const int ip  = iqs / 128;         // 0 or 1
-    const int il  = (iqs - 128*ip)/16; // 0...7
+    const int tid = tiisg/2;
+    const int ix  = tiisg%2;
+    const int ip  = tid/8;  // 0 or 1
+    const int il  = tid%8;
     const int n   = 4;
     const int l0  = n*il;
     const int is  = 8*ip + l0/16;
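How the 32 lanes of one simdgroup split the work in the new QK_K == 256 path: tiisg is decomposed into a block stride ix and a 16-way slice index tid, which in turn selects the 128-value half (ip) and the 4-value slice (il) that the lane processes. A small C++ sketch that only enumerates this mapping (the printout is illustrative):

    // Lane-to-work mapping implied by the index math above (plain C++ sketch).
    #include <cstdio>

    int main() {
        for (int lane = 0; lane < 32; ++lane) {   // lane plays the role of tiisg
            const int tid = lane / 2;             // 0..15: which 4-value slice
            const int ix  = lane % 2;             // 0..1 : starting block, loop stride is 2
            const int ip  = tid / 8;              // 0 or 1: first or second 128 values
            const int il  = tid % 8;              // 0..7 : slice within that half
            const int l0  = 4 * il;
            const int is  = 8 * ip + l0 / 16;     // sub-block scale index
            std::printf("lane %2d: ix=%d ip=%d il=%d l0=%2d is=%2d\n",
                        lane, ix, ip, il, l0, is);
        }
        return 0;
    }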
@@ -1802,9 +1800,10 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int q_offset_l = 64*ip + l0;
     const int q_offset_h = 32*ip + l0;
 
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    for (int i = ix; i < nb; i += 2) {
 
-        device const uint8_t * ql = x[i].ql + q_offset_l;
+        device const uint8_t * q1 = x[i].ql + q_offset_l;
+        device const uint8_t * q2 = q1 + 32;
         device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t  * sc = x[i].scales + is;
 
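The pointer arithmetic here (q1 = ql + q_offset_l, q2 = q1 + 32) assumes the Q6_K super-block layout from k_quants.h. A reference sketch of that struct, with ggml_fp16_t approximated by uint16_t, is shown below; it is context for reading the diff, not part of the change:

    // Q6_K super-block layout assumed by the kernel's pointer arithmetic
    // (mirrors block_q6_K from k_quants.h; uint16_t stands in for ggml_fp16_t).
    #include <cstdint>

    #define QK_K 256

    typedef struct {
        uint8_t  ql[QK_K/2];      // lower 4 bits of the 256 quants (128 bytes)
        uint8_t  qh[QK_K/4];      // upper 2 bits of the 256 quants  (64 bytes)
        int8_t   scales[QK_K/16]; // 16 signed 8-bit sub-block scales
        uint16_t d;               // super-block scale (fp16 bit pattern)
    } block_q6_K;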
@@ -1814,19 +1813,21 @@ kernel void kernel_mul_mat_q6_K_f32(
 
         float4 sums = {0.f, 0.f, 0.f, 0.f};
         for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >>  4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((ql[l+32] >>  4) | ((qh[l] & kmask4) >> 2)) - 32);
+            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((q1[l] >>  4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((q2[l] >>  4) | ((qh[l] & kmask4) >> 2)) - 32);
         }
 
         sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
 
     }
-#else
-    const int il = 4*tpitg.x; // 0, 4, 8, 12
-
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+#else
+    const int ix = tiisg/4;
+    const int il = 4*(tiisg%4);
+
+    for (int i = ix; i < nb; i += 8) {
         device const float   * y  = yy + i * QK_K + il;
         device const uint8_t * ql = x[i].ql + il;
         device const uint8_t * qh = x[i].qh + il;
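The four accumulators correspond to the four 32-value sub-blocks that share one byte of ql and qh: the low and high nibbles of q1[l] and q2[l] give the low 4 bits, and two bits selected from qh[l] by kmask1..kmask4 give the high 2 bits. A plain C++ restatement of that reconstruction; kmask3 = 0x30 and kmask4 = 0xC0 are assumed here, since the diff only shows kmask1 and kmask2:

    // Reassembling four 6-bit quants (one per sub-block) from the packed planes,
    // following the sums[0..3] lines above, then biasing by -32.
    #include <cstdint>

    static const uint8_t kmask1 = 0x03, kmask2 = 0x0C;
    static const uint8_t kmask3 = 0x30, kmask4 = 0xC0;   // assumed, not shown in the diff

    static void dequant4(const uint8_t * q1, const uint8_t * q2, const uint8_t * qh,
                         int l, int q[4]) {
        q[0] = (int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32;  // pairs with y[l+ 0]
        q[1] = (int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32;  // pairs with y[l+32]
        q[2] = (int8_t)((q1[l] >>  4) | ((qh[l] & kmask3) << 0)) - 32;  // pairs with y[l+64]
        q[3] = (int8_t)((q2[l] >>  4) | ((qh[l] & kmask4) >> 2)) - 32;  // pairs with y[l+96]
    }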
@@ -1846,23 +1847,8 @@ kernel void kernel_mul_mat_q6_K_f32(
 
 #endif
 
-    sum[ith] = sumf;
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
-    }
-
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst[r1*ne0 + row] = tot;
+    }
 }
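What replaces the staged threadgroup reduction: simd_sum adds sumf across the 32 lanes of a simdgroup, so only lane 0 (tiisg == 0) stores the row result and no threadgroup memory or barriers are needed. A scalar C++ emulation of that final step, with made-up per-lane values for illustration:

    // Scalar emulation of the new ending: one simdgroup = 32 partial sums,
    // simd_sum collapses them, lane 0 writes dst[r1*ne0 + row].
    #include <cstdio>

    int main() {
        float sumf[32];                                  // per-lane partial dot products
        for (int t = 0; t < 32; ++t) sumf[t] = 0.01f * t;   // made-up values

        float tot = 0.0f;                                // simd_sum(sumf) equivalent
        for (int t = 0; t < 32; ++t) tot += sumf[t];

        std::printf("row result = %g\n", tot);           // kernel: dst[r1*ne0 + row] = tot
        return 0;
    }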