From 0099570f043d6420d33cde82f8a6e87cb8ff7113 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Fri, 21 Jul 2023 09:14:12 +0300 Subject: [PATCH] Q3_K for QK_K = 64 --- ggml-metal.m | 11 ++++++++-- ggml-metal.metal | 57 ++++++++++++++++++++++++++++++++++-------------- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 2e5e66381..2810fa2a8 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -743,10 +743,17 @@ void ggml_metal_graph_compute( src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q5_K) { + else if (src0t == GGML_TYPE_Q3_K) { +#ifdef GGML_QKK_64 + [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#else + [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; +#endif + } + else if (src0t == GGML_TYPE_Q5_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_Q3_K || src0t == GGML_TYPE_Q6_K) { + else if (src0t == GGML_TYPE_Q6_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0]; diff --git a/ggml-metal.metal b/ggml-metal.metal index 11492f814..e56fe2248 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -351,7 +351,7 @@ kernel void kernel_rms_norm( threadgroup_barrier(mem_flags::mem_threadgroup); // broadcast, simd group number is ntg / 32 - for (int i = ntg / 32 / 2; i > 0; i /= 2) { + for (uint i = ntg / 32 / 2; i > 0; i /= 2) { if (tpitg < i) { sum[tpitg] += sum[tpitg + i]; } @@ -1339,6 +1339,7 @@ kernel void kernel_mul_mat_q2_K_f32( } } +#if QK_K == 256 kernel void kernel_mul_mat_q3_K_f32( device const void * src0, device const float * src1, @@ -1363,8 +1364,6 @@ kernel void kernel_mul_mat_q3_K_f32( float yl[16]; -#if QK_K == 256 - const uint16_t kmask1 = 0x0303; const uint16_t kmask2 = 0x0f0f; @@ -1392,7 +1391,7 @@ kernel void kernel_mul_mat_q3_K_f32( const int q_offset = 32*ip + l0; const int y_offset = 128*ip + 32*il + l0; - const int step = sizeof(block_q3_K) * nb; + const int step = sizeof(block_q3_K) * nb / 2; device const float * y1 = yy + ix*QK_K + y_offset; @@ -1434,17 +1433,47 @@ kernel void kernel_mul_mat_q3_K_f32( sumf1[row] += d * scales[1]; sumf2[row] += d; - q += step/2; - h += step/2; - a += step/2; - dh += step/2; + q += step; + h += step; + a += step; + dh += step; } y1 += 2 * QK_K; } + + for (int row = 0; row < 2; ++row) { + const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = tot; + } + } +} #else +kernel void kernel_mul_mat_q3_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne1, + uint2 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int nb = ne00/QK_K; + + const int64_t r0 = tgpig.x; + const int64_t r1 = tgpig.y; + + const int row = 2 * r0 + sgitg; + + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb; + device const float * yy = (device const float *) src1 + r1*ne10; const int ix = tiisg/4; const int il = 4 * (tiisg%4);// 0, 4, 8, 12 const int im = il/8; // 0, 0, 1, 1 @@ -1481,17 +1510,14 @@ kernel void kernel_mul_mat_q3_K_f32( } const float sumf = sum1[0] + sum1[1] * 1.f/4.f + sum1[2] * 1.f/16.f + sum1[3] * 1.f/64.f + (sum2[0] + sum2[1] * 1.f/4.f + sum2[2] * 1.f/16.f + sum2[3] * 1.f/64.f) * 1.f/256.f; -#endif - for (int row = 0; row < 2; ++row) { - const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift); - const float tot = simd_sum(sumf); - if (tiisg == 0) { - dst[r1*ne0 + first_row + row] = tot; - } + const float tot = simd_sum(sumf); + if (tiisg == 0) { + dst[r1*ne0 + row] = tot; } } +#endif #if QK_K == 256 kernel void kernel_mul_mat_q4_K_f32( @@ -1789,7 +1815,6 @@ kernel void kernel_mul_mat_q5_K_f32( for (int i = ix; i < nb; i += 8) { - float4 sumy = {0.f, 0.f, 0.f, 0.f}; for (int l = 0; l < 4; ++l) { yl[l+0] = y[l+ 0]; yl[l+4] = y[l+16];