From 9b4de68e733b0271827504fa0da78ff223db2603 Mon Sep 17 00:00:00 2001
From: Iwan Kawrakow
Date: Fri, 9 Jun 2023 10:26:42 +0300
Subject: [PATCH] metal : 17% faster Q4_0

Use 64 threads in a thread group.
---
 ggml-metal.m     |  2 +-
 ggml-metal.metal | 39 ++++++++++-----------------------------
 2 files changed, 11 insertions(+), 30 deletions(-)

diff --git a/ggml-metal.m b/ggml-metal.m
index ac4f1346c..54cbaf860 100644
--- a/ggml-metal.m
+++ b/ggml-metal.m
@@ -526,7 +526,7 @@ void ggml_metal_graph_compute(
                             GGML_ASSERT(ne12 == 1);
 
                             nth0 = 8;
-                            nth1 = 4;
+                            nth1 = 8;
                             [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
                         } break;
                     case GGML_TYPE_Q2_K:
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 981bb41af..8e730eb9c 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -278,53 +278,33 @@ kernel void kernel_mul_mat_q4_0_f32(
     const uint nth = tptg.x*tptg.y;
     const uint ith = tptg.y*tpitg.x + tpitg.y;
 
-    const int first = 4 * tpitg.y;
+    const int ix = tpitg.y/4;       // 0 or 1
+    const int iy = tpitg.y - 4*ix;  // 0...3
 
-    sum[ith] = 0.0f;
+    const int first = 4 * iy;
 
     float sumf = 0;
 
-    //float sumf1 = 0;
-    //float sumf2 = 0;
-
-    for (int i = tpitg.x; i < nb; i += tptg.x) {
-        device const uchar * x0p = (device const uchar *) (x + i)->qs;
-        device const float * y0p = (device const float *) (y + i*QK4_0);
+    for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
 
-        const float d = (float)((x + i)->d);
+        const float d = (float)x[i].d;
 
-        device const uchar * x0v = x0p + first;
-        device const float * y0v = y0p + first;
-        device const float * y1v = y0p + first + 16;
+        device const uint8_t * xl = x[i].qs + first;
+        device const float * yl = y + i * QK4_0 + first;
 
-        //float3 acc = {0.0f, 0.0f, 0.f};
         float2 acc = {0.0f, 0.0f};
 
         for (int j = 0; j < 4; ++j) {
-            //acc[0] += y0v[j] * (x0v[j] & 0xF);
-            //acc[1] += y1v[j] * (x0v[j] >> 4);
-            //acc[2] += y0v[j] + y1v[j];
+            acc[0] += yl[j+ 0] * ((int8_t)(xl[j] & 0xF) - m8);
+            acc[1] += yl[j+16] * ((int8_t)(xl[j] >> 4) - m8);
 
-            acc[0] += y0v[j] * ((int8_t)(x0v[j] & 0xF) - m8);
-            acc[1] += y1v[j] * ((int8_t)(x0v[j] >> 4) - m8);
-
-            //const int x0 = x0v[j] & 0x0F;
-            //const int x1 = x0v[j] >> 4;
-
-            //const float y0 = y0v[j];
-            //const float y1 = y1v[j];
-
-            //acc += (x0 - 8)*y0 + (x1 - 8)*y1;
         }
 
-        //sum[ith] += acc*d;
         sumf += d * (acc[0] + acc[1]);
-        //sumf1 += d * (acc[0] + acc[1]);
-        //sumf2 += d * acc[2];
     }
 
     sum[ith] = sumf;
-    //sum[ith] = sumf1 - 8.f*sumf2;
 
     //
    // Accumulate the sum from all threads in the threadgroup
@@ -380,6 +360,7 @@ kernel void kernel_mul_mat_f16_f32(
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3 tptg[[threads_per_threadgroup]]) {
+
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
     const int64_t im = tgpig.z;
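
Editor's note (not part of the patch): with the host-side change to nth0 = 8, nth1 = 8, the kernel runs 64 threads per threadgroup. The new indexing uses tpitg.x to pick a pair of Q4_0 blocks, ix to pick one block of that pair, and iy to pick which 4 of the 16 nibble positions a thread accumulates (each nibble byte holds one low-nibble and one high-nibble weight, 16 bytes per 32-weight block). The standalone C sketch below simulates that mapping and checks that every (block, weight) position of a row is covered exactly once; it is an illustration only, and the block count nb = 64 is an arbitrary example value, not something fixed by the patch.

    /* Sketch: verify the 64-thread work split used by the new kernel loop. */
    #include <stdio.h>

    #define QK4_0 32   /* weights per Q4_0 block: 16 low nibbles + 16 high nibbles */

    int main(void) {
        const int tptg_x = 8, tptg_y = 8;   /* threads_per_threadgroup, from nth0/nth1 = 8 */
        const int nb = 64;                  /* blocks per row, example value */

        static int hits[64][QK4_0];         /* zero-initialized coverage counters */

        for (int tx = 0; tx < tptg_x; ++tx) {
            for (int ty = 0; ty < tptg_y; ++ty) {
                const int ix    = ty / 4;       /* 0 or 1: which block of a pair   */
                const int iy    = ty - 4 * ix;  /* 0..3: which group of 4 nibbles  */
                const int first = 4 * iy;

                /* Same loop structure as the new kernel_mul_mat_q4_0_f32 loop. */
                for (int i = 2 * tx + ix; i < nb; i += 2 * tptg_x) {
                    for (int j = 0; j < 4; ++j) {
                        hits[i][first + j]      += 1;  /* low-nibble weights  */
                        hits[i][first + j + 16] += 1;  /* high-nibble weights */
                    }
                }
            }
        }

        for (int i = 0; i < nb; ++i) {
            for (int k = 0; k < QK4_0; ++k) {
                if (hits[i][k] != 1) {
                    printf("bad coverage at block %d, position %d\n", i, k);
                    return 1;
                }
            }
        }

        printf("all %d x %d weight positions covered exactly once\n", nb, QK4_0);
        return 0;
    }

Compared with the old scheme (one thread per block, 16 nibble bytes each), each thread now touches only 4 contiguous bytes of quants and 8 floats per block, which is what the commit message attributes the speedup to.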