From 5bb23b5ab5ddd2377bbb91abe2ac468859156010 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Thu, 20 Jul 2023 20:07:38 +0300 Subject: [PATCH] Faster Q3_K on Metal --- ggml-metal.m | 10 ++-- ggml-metal.metal | 116 +++++++++++++++++++++++------------------------ 2 files changed, 59 insertions(+), 67 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 135bda9fc..15d265fa0 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -685,8 +685,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32]; } break; case GGML_TYPE_Q4_K: @@ -746,12 +746,8 @@ void ggml_metal_graph_compute( else if (src0t == GGML_TYPE_Q5_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_Q6_K) { + else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; - } - else if (src0t == GGML_TYPE_Q3_K) { - [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; diff --git a/ggml-metal.metal b/ggml-metal.metal index 97f5c10ba..a11fcfad2 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -1347,40 +1347,41 @@ kernel void kernel_mul_mat_q3_K_f32( constant int64_t & ne10, constant int64_t & ne0, constant int64_t & ne1, - threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q3_K * x = (device const block_q3_K *) src0 + r0*nb; + const int row = 2 * r0 + sgitg; + + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb; device const float * yy = (device const float *) src1 + r1*ne10; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; - #if QK_K == 256 - const uint8_t m3 = 3; - const int8_t m4 = 4; - const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; - const int tid = tpitg.y; // expecting 16 + const int tid = tiisg/2; + const int ix = tiisg%2; const int ip = tid/8; // 0 or 1 const int il = tid/2 - 4*ip; // 0...3 const int ir = tid%2; const int n = 8; const int l0 = n*ir; - const uint8_t m = 1 << (4*ip + il); + const uint16_t m1 = 1 << (4*ip + il); + const uint16_t m2 = m1 << 8; const int shift = 2*il; + const uint16_t qm1 = 0x0003 << shift; + const uint16_t qm2 = 0x0300 << shift; + const int32_t v1 = 4 << shift; + const int32_t v2 = 1024 << shift; const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + 2*(il/2); @@ -1389,90 +1390,85 @@ kernel void kernel_mul_mat_q3_K_f32( const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; - //float sumf = 0; float sumf1 = 0, sumf2 = 0; - for (int i = tpitg.x; i < nb; i += tptg.x) { + for (int i = ix; i < nb; i += 2) { const float d_all = (float)(x[i].d); - device const uint8_t * q = x[i].qs + q_offset; - device const uint8_t * h = x[i].hmask + l0; - device const float * y = yy + i * QK_K + y_offset; + device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + l0); + device const float * y = yy + i * QK_K + y_offset; device const uint16_t * a = (device const uint16_t *)x[i].scales; const char2 scales = as_type((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4))); - float s = 0; - for (int l = 0; l < n; ++l) { - s += y[l+ 0] * ((int8_t)((q[l+ 0] >> shift) & m3) - ((h[l+ 0] & m) ? 0 : m4)); + float s1 = 0, s2 = 0; + for (int l = 0; l < n; l += 2) { + const uint16_t qs = q[l/2]; + s1 += y[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); + s2 += y[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); } - float d = d_all * s; + float d = d_all * (s1 + 1.f/256.f * s2); sumf1 += d * scales[0]; sumf2 += d; - //sumf += d_all * s * (scales[0] - 32); - s = 0; - for (int l = 0; l < n; ++l) { - s += y[l+16] * ((int8_t)((q[l+16] >> shift) & m3) - ((h[l+16] & m) ? 0 : m4)); + y += 16; + q += 8; + h += 8; + s1 = s2 = 0; + for (int l = 0; l < n; l += 2) { + const uint16_t qs = q[l/2]; + s1 += y[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1)); + s2 += y[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2)); } - d = d_all * s; + d = d_all * (s1 + 1.f/256.f * s2); sumf1 += d * scales[1]; sumf2 += d; - //sumf += d_all * s * (scales[1] - 32); } - - //sum[ith] = sumf; - sum[ith] = sumf1 - 32.f*sumf2; + const float sumf = (sumf1 - 32.f*sumf2) / (1 << shift); #else - const int il = 4 * tpitg.x; // 0, 4, 8, 12 + const int ix = tiisg/4; + const int il = 4 * (tiisg%4);// 0, 4, 8, 12 const int im = il/8; // 0, 0, 1, 1 const int in = il%8; // 0, 4, 0, 4 - float sumf = 0; + float4 sum1 = {0.f, 0.f, 0.f, 0.f}; + float4 sum2 = {0.f, 0.f, 0.f, 0.f}; - for (int i = tpitg.y; i < nb; i += tptg.y) { + for (int i = ix; i < nb; i += 8) { const float d_all = (float)(x[i].d); - device const uint8_t * q = x[i].qs + il; - device const uint8_t * h = x[i].hmask + in; - device const float * y = yy + i * QK_K + il; + device const uint16_t * q = (device const uint16_t *)(x[i].qs + il); + device const uint16_t * h = (device const uint16_t *)(x[i].hmask + in); + device const float * y = yy + i * QK_K + il; const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8); const float d2 = d_all * ((x[i].scales[0] >> 4) - 8); const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8); const float d4 = d_all * ((x[i].scales[1] >> 4) - 8); - for (int l = 0; l < 4; ++l) { - const uint8_t hm = h[l] >> im; - sumf += y[l+ 0] * d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((hm & 0x01) ? 0 : 4)) - + y[l+16] * d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((hm & 0x04) ? 0 : 4)) - + y[l+32] * d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((hm & 0x10) ? 0 : 4)) - + y[l+48] * d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((hm & 0x40) ? 0 : 4)); + for (int l = 0; l < 4; l += 2) { + const uint16_t hm = h[l/2] >> im; + sum1[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)); + sum1[1] += y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)); + sum1[2] += y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)); + sum1[3] += y[l+48] * d4 * ((int32_t)(q[l/2] & 0x00c0) - ((hm & 0x0040) ? 0 : 256)); + sum2[0] += y[l+ 1] * d1 * ((int32_t)(q[l/2] & 0x0300) - ((hm & 0x0100) ? 0 : 1024)); + sum2[1] += y[l+17] * d2 * ((int32_t)(q[l/2] & 0x0c00) - ((hm & 0x0400) ? 0 : 4096)); + sum2[2] += y[l+33] * d3 * ((int32_t)(q[l/2] & 0x3000) - ((hm & 0x1000) ? 0 : 16384)); + sum2[3] += y[l+49] * d4 * ((int32_t)(q[l/2] & 0xc000) - ((hm & 0x4000) ? 0 : 65536)); } } - - sum[ith] = sumf; - + const float sumf = sum1[0] + sum1[1] * 1.f/4.f + sum1[2] * 1.f/16.f + sum1[3] * 1.f/64.f + + (sum2[0] + sum2[1] * 1.f/4.f + sum2[2] * 1.f/16.f + sum2[3] * 1.f/64.f) * 1.f/256.f; #endif - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + row] = tot; } }