From 355e8c6e955994761d7687afdc7fdce07526e60b Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Sat, 10 Jun 2023 19:05:32 +0300 Subject: [PATCH] metal : some more optimizations Q2_K: 25.4 ms/token Q6_K: 27.3 ms/token Q4_0: 22.8 ms/token Q4_1: 23.1 ms/token --- ggml-metal.metal | 172 +++++++++++++---------------------------------- 1 file changed, 48 insertions(+), 124 deletions(-) diff --git a/ggml-metal.metal b/ggml-metal.metal index 789aa5d5e..fd990a4f5 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -304,26 +304,14 @@ kernel void kernel_mul_mat_q4_0_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, constant int64_t & ne0, - constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { const int nb = ne00/QK4_0; - const int8_t m8 = 8; - const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -351,47 +339,32 @@ kernel void kernel_mul_mat_q4_0_f32( for (int j = 0; j < 4; ++j) { - acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8); - acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8); + acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4); + acc[1] += yl[j] + yl[j+16]; } - sumf += d * (acc[0] + acc[1]); + sumf += d * (acc[0] - 8.f*acc[1]); } sum[ith] = sumf; // // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. // threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + for (uint i = 16; i < nth; i += 16) sum[0] += sum[i]; dst[r1*ne0 + r0] = sum[0]; } - - //// accumulate the sum from all threads in the threadgroup - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = nth/2; i > 0; i /= 2) { - // if (ith < i) { - // sum[ith] += sum[ith + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - - //if (ith == 0) { - // dst[r1*ne0 + r0] = sum[0]; - //} } kernel void kernel_mul_mat_q4_1_f32( @@ -399,20 +372,10 @@ kernel void kernel_mul_mat_q4_1_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, constant int64_t & ne0, - constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { const int nb = ne00/QK4_1; @@ -460,11 +423,11 @@ kernel void kernel_mul_mat_q4_1_f32( // threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; + sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; + sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; } threadgroup_barrier(mem_flags::mem_threadgroup); if (ith == 0) { @@ -847,20 +810,10 @@ kernel void kernel_mul_mat_q2_k_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, constant int64_t & ne0, - constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], // we don't use this for now uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { @@ -875,7 +828,6 @@ kernel void kernel_mul_mat_q2_k_f32( const int nth = tptg.x*tptg.y; const int ith = tptg.y*tpitg.x + tpitg.y; - const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 const int ir = tid%4; // 0...3 @@ -885,35 +837,54 @@ kernel void kernel_mul_mat_q2_k_f32( const int n = 8; const int is = 4*il + (n*ir)/16; + const int y_offset = 64*il + n*ir; + const int q_offset = 32*ip + n*ir; + sum[ith] = 0.0f; float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { - device const uint8_t * q = x[i].qs + 32*ip + n*ir; + device const uint8_t * q = x[i].qs + q_offset; device const uint8_t * scales = x[i].scales + is; uint8_t d1 = scales[0] & 0xF; - uint8_t m1 = scales[0] >> 4; uint8_t d2 = scales[2] & 0xF; + uint8_t m1 = scales[0] >> 4; uint8_t m2 = scales[2] >> 4; - device const float * y = yy + i*QK_K + 64*il + n*ir; + device const float * y = yy + i*QK_K + y_offset; + + //float4 s = {0.f, 0.f, 0.f, 0.f}; + float2 s = {0.f, 0.f}; + float smin = 0; + for (int l = 0; l < n; ++l) { + s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); + s[1] += y[l+32] * ((q[l] >> shift2) & 3); + smin += y[l+ 0] * m1 + y[l+32] * m2; + } const float dall = (float)x[i].d; const float dmin = (float)x[i].dmin; - float4 s = {0.f, 0.f, 0.f, 0.f}; - for (int l = 0; l < n; ++l) { - s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); s[1] += y[l+ 0]; - s[2] += y[l+32] * ((q[l] >> shift2) & 3); s[3] += y[l+32]; - } - sumf += dall * (s[0] * d1 + s[2] * d2) - dmin * (s[1] * m1 + s[3] * m2); - + sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin; } sum[ith] = sumf; + //int mask1 = (ith%4 == 0); + //int mask2 = (ith%16 == 0); + + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (int i = 1; i < 4; ++i) sum[ith] += mask1 * sum[ith + i]; + //threadgroup_barrier(mem_flags::mem_threadgroup); + //for (int i = 4; i < 16; i += 4) sum[ith] += mask2 * sum[ith + i]; + //threadgroup_barrier(mem_flags::mem_threadgroup); + //if (ith == 0) { + // for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; + // dst[r1*ne0 + r0] = sum[0]; + //} + // // Accumulate the sum from all threads in the threadgroup // This version is slightly faster than the commented out one below, @@ -952,20 +923,10 @@ kernel void kernel_mul_mat_q4_k_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, constant int64_t & ne0, - constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], // we don't use this for now uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { @@ -986,7 +947,6 @@ kernel void kernel_mul_mat_q4_k_f32( const int tid = tpitg.y; // 0...16 const int il = tid/4; // 0...3 - //const int ir = tid%4; // 0...3 const int ir = tid - 4*il;// 0...3 const int n = 4; @@ -999,12 +959,6 @@ kernel void kernel_mul_mat_q4_k_f32( sum[ith] = 0.0f; - //uint16_t aux_scales[4]; - //thread uint8_t * sc = (thread uint8_t *)aux_scales; - - //uint32_t aux32[4]; - //thread const uint8_t * sc = (thread const uint8_t *)aux32; - uchar2 sc1, sc2, sc3, sc4; float sumf = 0; @@ -1018,46 +972,21 @@ kernel void kernel_mul_mat_q4_k_f32( const float dall = (float)((x + i)->d); const float dmin = (float)((x + i)->dmin); - //device const uint32_t * a = (device const uint32_t *)(x + i)->scales; - //aux32[0] = a[0] & 0x3f3f3f3f; // scales for 0, 32, 64, 96 - //aux32[1] = a[1] & 0x3f3f3f3f; // mins for 0, 32, 64, 96 - //aux32[2] = ((a[2] >> 0) & 0x0f0f0f0f) | ((a[0] & 0xc0c0c0c0) >> 2); // scales for 128, 160, 192, 224 - //aux32[3] = ((a[2] >> 4) & 0x0f0f0f0f) | ((a[1] & 0xc0c0c0c0) >> 2); // mins for 128, 160, 192, 224 - - //aux_scales[0] = (uint16_t)(a[im+0] & kmask1); - //aux_scales[1] = (uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2)); - //aux_scales[2] = (uint16_t)(a[im+2] & kmask1); - //aux_scales[3] = (uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2)); - device const uint16_t * a = (device const uint16_t *)(x + i)->scales; sc1 = as_type((uint16_t)(a[im+0] & kmask1)); sc2 = as_type((uint16_t)(a[im+2] & kmask1)); sc3 = as_type((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2))); sc4 = as_type((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2))); - //float2 s = {0.f, 0.f}; float4 s = {0.f, 0.f, 0.f, 0.f}; float smin = 0; for (int l = 0; l < n; ++l) { - ////s[0] += y1[l] * sc[0] * (q1[l] & 0xF) + y1[l+32] * sc[1] * (q1[l] >> 4) - //// + y2[l] * sc[2] * (q2[l] & 0xF) + y2[l+32] * sc[3] * (q2[l] >> 4); - ////s[1] += y1[l] * sc[4] + y1[l+32] * sc[5] + y2[l] * sc[6] + y2[l+32] * sc[7]; - - ////s[0] += y1[l] * sc[2*im+0] * (q1[l] & 0xF) + y1[l+32] * sc[2*im+1] * (q1[l] >> 4) - //// + y2[l] * sc[2*im+8] * (q2[l] & 0xF) + y2[l+32] * sc[2*im+9] * (q2[l] >> 4); - ////s[1] += y1[l] * sc[2*im+4] + y1[l+32] * sc[2*im+5] + y2[l] * sc[2*im+12] + y2[l+32] * sc[2*im+13]; - - //s[0] += y1[l] * sc1[0] * (q1[l] & 0xF) + y1[l+32] * sc1[1] * (q1[l] >> 4) - // + y2[l] * sc3[0] * (q2[l] & 0xF) + y2[l+32] * sc3[1] * (q2[l] >> 4); - //s[1] += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1]; - s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4); s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4); smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1]; } - //sumf += dall * s[0] - dmin * s[1]; sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; } @@ -1102,20 +1031,10 @@ kernel void kernel_mul_mat_q6_k_f32( device const float * src1, device float * dst, constant int64_t & ne00, - constant int64_t & ne01, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, constant int64_t & ne10, - constant int64_t & ne11, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, constant int64_t & ne0, - constant int64_t & ne1, threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpig[[thread_position_in_grid]], // we don't use this for now uint2 tpitg[[thread_position_in_threadgroup]], uint2 tptg[[threads_per_threadgroup]]) { @@ -1135,21 +1054,26 @@ kernel void kernel_mul_mat_q6_k_f32( const uint nth = tptg.x*tptg.y; const uint ith = tptg.y*tpitg.x + tpitg.y; - const int step = QK_K / tptg.y; // we expect this to be 16 - const int iqs = step * tpitg.y; // 0...240 in steps of 16 + // Note: we absolutely assume that tptg.y = 16 and QK_K = 256! + const int iqs = 16 * tpitg.y; const int ip = iqs / 128; // 0 or 1 const int il = (iqs - 128*ip)/16; // 0...7 const int n = 4; - const int is = 8*ip + (n*il)/16; + const int l0 = n*il; + const int is = 8*ip + l0/16; + + const int y_offset = 128*ip + l0; + const int q_offset_l = 64*ip + l0; + const int q_offset_h = 32*ip + l0; float sumf = 0; for (int i = tpitg.x; i < nb; i += tptg.x) { - device const uint8_t * ql = x[i].ql + 64*ip + n*il; - device const uint8_t * qh = x[i].qh + 32*ip + n*il; + device const uint8_t * ql = x[i].ql + q_offset_l; + device const uint8_t * qh = x[i].qh + q_offset_h; device const int8_t * sc = x[i].scales + is; - device const float * y = yy + i * QK_K + 128*ip + n*il; + device const float * y = yy + i * QK_K + y_offset; const float dall = x[i].d;