Faster Q6_K on Metal
This commit is contained in:
parent
785829dfe8
commit
fa9d54e36e
2 changed files with 30 additions and 43 deletions
11
ggml-metal.m
11
ggml-metal.m
|
@@ -712,8 +712,8 @@ void ggml_metal_graph_compute(
|
|||
GGML_ASSERT(ne02 == 1);
|
||||
GGML_ASSERT(ne12 == 1);
|
||||
|
||||
nth0 = 4;
|
||||
nth1 = 16;
|
||||
nth0 = 2;
|
||||
nth1 = 32;
|
||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
|
||||
} break;
|
||||
default:
|
||||
|
@@ -743,11 +743,12 @@ void ggml_metal_graph_compute(
|
|||
src0t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src0t == GGML_TYPE_Q2_K ||
|
||||
src0t == GGML_TYPE_Q3_K ||
|
||||
src0t == GGML_TYPE_Q4_K ||
|
||||
src0t == GGML_TYPE_Q5_K ||
|
||||
src0t == GGML_TYPE_Q6_K) {
|
||||
src0t == GGML_TYPE_Q5_K) {
|
||||
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
|
|
|
@@ -1766,10 +1766,9 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|||
constant int64_t & ne00,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne0,
|
||||
threadgroup float * sum [[threadgroup(0)]],
|
||||
uint2 tgpig[[threadgroup_position_in_grid]],
|
||||
uint2 tpitg[[thread_position_in_threadgroup]],
|
||||
uint2 tptg[[threads_per_threadgroup]]) {
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
const uint8_t kmask1 = 0x03;
|
||||
const uint8_t kmask2 = 0x0C;
|
||||
|
@@ -1781,19 +1780,18 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|||
const int64_t r0 = tgpig.x;
|
||||
const int64_t r1 = tgpig.y;
|
||||
|
||||
device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb;
|
||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||
const int row = 2 * r0 + sgitg;
|
||||
|
||||
const int nth = tptg.x*tptg.y;
|
||||
const int ith = tptg.y*tpitg.x + tpitg.y;
|
||||
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
|
||||
device const float * yy = (device const float *) src1 + r1*ne10;
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
#if QK_K == 256
|
||||
// Note: we absolutely assume that tptg.y = 16 and QK_K = 256!
|
||||
const int iqs = 16 * tpitg.y;
|
||||
const int ip = iqs / 128; // 0 or 1
|
||||
const int il = (iqs - 128*ip)/16; // 0...7
|
||||
const int tid = tiisg/2;
|
||||
const int ix = tiisg%2;
|
||||
const int ip = tid/8; // 0 or 1
|
||||
const int il = tid%8;
|
||||
const int n = 4;
|
||||
const int l0 = n*il;
|
||||
const int is = 8*ip + l0/16;
|
||||
|
@@ -1802,9 +1800,10 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|||
const int q_offset_l = 64*ip + l0;
|
||||
const int q_offset_h = 32*ip + l0;
|
||||
|
||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||
for (int i = ix; i < nb; i += 2) {
|
||||
|
||||
device const uint8_t * ql = x[i].ql + q_offset_l;
|
||||
device const uint8_t * q1 = x[i].ql + q_offset_l;
|
||||
device const uint8_t * q2 = q1 + 32;
|
||||
device const uint8_t * qh = x[i].qh + q_offset_h;
|
||||
device const int8_t * sc = x[i].scales + is;
|
||||
|
||||
|
@@ -1814,19 +1813,21 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|||
|
||||
float4 sums = {0.f, 0.f, 0.f, 0.f};
|
||||
for (int l = 0; l < n; ++l) {
|
||||
sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
||||
sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
||||
sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
||||
sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
||||
sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
|
||||
sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
|
||||
sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
|
||||
sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
|
||||
}
|
||||
|
||||
sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);
|
||||
|
||||
}
|
||||
#else
|
||||
const int il = 4*tpitg.x; // 0, 4, 8, 12
|
||||
|
||||
for (int i = tpitg.y; i < nb; i += tptg.y) {
|
||||
#else
|
||||
const int ix = tiisg/4;
|
||||
const int il = 4*(tiisg%4);
|
||||
|
||||
for (int i = ix; i < nb; i += 8) {
|
||||
device const float * y = yy + i * QK_K + il;
|
||||
device const uint8_t * ql = x[i].ql + il;
|
||||
device const uint8_t * qh = x[i].qh + il;
|
||||
|
@@ -1846,23 +1847,8 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|||
|
||||
#endif
|
||||
|
||||
sum[ith] = sumf;
|
||||
|
||||
//
|
||||
// Accumulate the sum from all threads in the threadgroup
|
||||
//
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (ith%4 == 0) {
|
||||
for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
|
||||
const float tot = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*ne0 + row] = tot;
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (ith%16 == 0) {
|
||||
for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
|
||||
}
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
if (ith == 0) {
|
||||
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
|
||||
dst[r1*ne0 + r0] = sum[0];
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue