From 96d005225fdb7803b1e1465623b935c2c878d8fb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 30 May 2023 22:13:43 +0300 Subject: [PATCH] mtl : mul_mat fixes (still wrong) --- examples/mtl/mtl.m | 30 ++++++++++++++---------------- examples/mtl/mtl.metal | 31 ++++++++++++++----------------- 2 files changed, 28 insertions(+), 33 deletions(-) diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m index bd424c23d..8985a0d74 100644 --- a/examples/mtl/mtl.m +++ b/examples/mtl/mtl.m @@ -377,29 +377,27 @@ int llama_mtl_eval( id id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1); id id_dst = llama_mtl_get_buffer(ctx, gf->nodes[i], &offs_dst); - const int64_t ncols0 = gf->nodes[i]->src0->ne[0]; - const int64_t nrows0 = gf->nodes[i]->src0->ne[1]; - - const int64_t ncols1 = gf->nodes[i]->src1->ne[0]; - const int64_t nrows1 = gf->nodes[i]->src1->ne[1]; - - const int64_t ncols = gf->nodes[i]->ne[0]; - const int64_t nrows = gf->nodes[i]->ne[1]; + const int64_t ne00 = gf->nodes[i]->src0->ne[0]; + const int64_t ne01 = gf->nodes[i]->src0->ne[1]; + const int64_t ne10 = gf->nodes[i]->src1->ne[0]; + const int64_t ne11 = gf->nodes[i]->src1->ne[1]; + const int64_t ne0 = gf->nodes[i]->ne[0]; + const int64_t ne1 = gf->nodes[i]->ne[1]; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ncols0 length:sizeof(ncols0) atIndex:3]; - [encoder setBytes:&nrows0 length:sizeof(nrows0) atIndex:4]; - [encoder setBytes:&ncols1 length:sizeof(ncols1) atIndex:5]; - [encoder setBytes:&nrows1 length:sizeof(nrows1) atIndex:6]; - [encoder setBytes:&ncols length:sizeof(ncols) atIndex:7]; - [encoder setBytes:&nrows length:sizeof(nrows) atIndex:8]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:5]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:7]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8]; - printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ncols0, nrows0, ncols1, nrows1, ncols, nrows); + printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ne00, ne01, ne10, ne11, ne0, ne1); - [encoder dispatchThreadgroups:MTLSizeMake(nrows0, nrows1, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)]; } break; case GGML_OP_GET_ROWS: { diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal index f67d24f71..348a432ab 100644 --- a/examples/mtl/mtl.metal +++ b/examples/mtl/mtl.metal @@ -144,19 +144,16 @@ kernel void kernel_mul_mat_q4_0( sum[tpitg.x] = 0.0f; for (int i = 0; i < nb; i += tptg.x) { - device const uint4 * x0p = (device const uint4 *) (x + i); + device const uint4 * x0p = (device const uint4 *) (x + i)->qs; device const float4 * y0p = (device const float4 *) (y + i*qk); const uint4 x0 = *x0p; - const uint4 x0l = x0 & uint4(0x0F0F0F0F); - const uint4 x0h = x0 >> 4; + const uint4 x0l = (x0 & uint4(0x0F0F0F0F)); + const uint4 x0h = (x0 & uint4(0xF0F0F0F0)) >> 4; - const int4 x0ls = as_type(x0l) - int4(8); - const int4 x0hs = as_type(x0h) - int4(8); - - thread const uchar * x0lsb = (thread const uchar *) &x0ls; - thread const uchar * x0hsb = (thread const uchar *) &x0hs; + thread const char * x0lsb = (thread const char *) &x0l; + thread const char * x0hsb = (thread const char *) &x0h; const float4 y00 = *(y0p + 0); const float4 y01 = *(y0p + 1); @@ -167,17 +164,17 @@ kernel void kernel_mul_mat_q4_0( const float4 y06 = *(y0p + 6); const float4 y07 = *(y0p + 7); - const float d = (x + i)->d; + const half d = (x + i)->d; sum[tpitg.x] += ( - x0lsb[ 0]*y00[0] + x0lsb[ 1]*y00[1] + x0lsb[ 2]*y00[2] + x0lsb[ 3]*y00[3] + - x0lsb[ 4]*y01[0] + x0lsb[ 5]*y01[1] + x0lsb[ 6]*y01[2] + x0lsb[ 7]*y01[3] + - x0lsb[ 8]*y02[0] + x0lsb[ 9]*y02[1] + x0lsb[10]*y02[2] + x0lsb[11]*y02[3] + - x0lsb[12]*y03[0] + x0lsb[13]*y03[1] + x0lsb[14]*y03[2] + x0lsb[15]*y03[3] + - x0hsb[ 0]*y04[0] + x0hsb[ 1]*y04[1] + x0hsb[ 2]*y04[2] + x0hsb[ 3]*y04[3] + - x0hsb[ 4]*y05[0] + x0hsb[ 5]*y05[1] + x0hsb[ 6]*y05[2] + x0hsb[ 7]*y05[3] + - x0hsb[ 8]*y06[0] + x0hsb[ 9]*y06[1] + x0hsb[10]*y06[2] + x0hsb[11]*y06[3] + - x0hsb[12]*y07[0] + x0hsb[13]*y07[1] + x0hsb[14]*y07[2] + x0hsb[15]*y07[3] + (x0lsb[ 0] - 8)*y00[0] + (x0lsb[ 1] - 8)*y00[1] + (x0lsb[ 2] - 8)*y00[2] + (x0lsb[ 3] - 8)*y00[3] + + (x0lsb[ 4] - 8)*y01[0] + (x0lsb[ 5] - 8)*y01[1] + (x0lsb[ 6] - 8)*y01[2] + (x0lsb[ 7] - 8)*y01[3] + + (x0lsb[ 8] - 8)*y02[0] + (x0lsb[ 9] - 8)*y02[1] + (x0lsb[10] - 8)*y02[2] + (x0lsb[11] - 8)*y02[3] + + (x0lsb[12] - 8)*y03[0] + (x0lsb[13] - 8)*y03[1] + (x0lsb[14] - 8)*y03[2] + (x0lsb[15] - 8)*y03[3] + + (x0hsb[ 0] - 8)*y04[0] + (x0hsb[ 1] - 8)*y04[1] + (x0hsb[ 2] - 8)*y04[2] + (x0hsb[ 3] - 8)*y04[3] + + (x0hsb[ 4] - 8)*y05[0] + (x0hsb[ 5] - 8)*y05[1] + (x0hsb[ 6] - 8)*y05[2] + (x0hsb[ 7] - 8)*y05[3] + + (x0hsb[ 8] - 8)*y06[0] + (x0hsb[ 9] - 8)*y06[1] + (x0hsb[10] - 8)*y06[2] + (x0hsb[11] - 8)*y06[3] + + (x0hsb[12] - 8)*y07[0] + (x0hsb[13] - 8)*y07[1] + (x0hsb[14] - 8)*y07[2] + (x0hsb[15] - 8)*y07[3] ) * d; }