diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index 7ad1722c0..2de105640 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -27,6 +27,10 @@ struct ggml_mtl_context {
     id<MTLFunction>             function_mul;
     id<MTLComputePipelineState> pipeline_mul;
 
+    // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
+    id<MTLFunction>             function_mul_row;
+    id<MTLComputePipelineState> pipeline_mul_row;
+
     id<MTLFunction>             function_scale;
     id<MTLComputePipelineState> pipeline_scale;
 
@@ -147,6 +151,10 @@ struct ggml_mtl_context * llama_mtl_init(
     ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil];
     fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul);
 
+    ctx->function_mul_row = [ctx->library newFunctionWithName:@"kernel_mul_row"];
+    ctx->pipeline_mul_row = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul_row error:nil];
+    fprintf(stderr, "%s: loaded kernel_mul_row: %p\n", __func__, (void *) ctx->pipeline_mul_row);
+
     ctx->function_scale = [ctx->library newFunctionWithName:@"kernel_scale"];
     ctx->pipeline_scale = [ctx->device newComputePipelineStateWithFunction:ctx->function_scale error:nil];
     fprintf(stderr, "%s: loaded kernel_scale: %p\n", __func__, (void *) ctx->pipeline_scale);
@@ -336,7 +344,14 @@ int llama_mtl_eval(
 
                         const int64_t ne00 = gf->nodes[i]->src0->ne[0];
 
-                        [encoder setComputePipelineState:ctx->pipeline_mul];
+                        const int64_t ne10 = gf->nodes[i]->src1->ne[0];
+
+                        if (ggml_nelements(gf->nodes[i]->src1) == ne10) {
+                            // src1 is a row
+                            [encoder setComputePipelineState:ctx->pipeline_mul_row];
+                        } else {
+                            [encoder setComputePipelineState:ctx->pipeline_mul];
+                        }
                         [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                         [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                         [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index 2c6386990..9ab51963f 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -39,9 +39,17 @@ kernel void kernel_add(
     dst[tpig] = src0[tpig] + src1[tpig];
 }
 
+kernel void kernel_mul(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src1[tpig];
+}
+
 // assumption: src1 is a row
 // broadcast src1 into src0
-kernel void kernel_mul(
+kernel void kernel_mul_row(
         device const float * src0,
         device const float * src1,
         device       float * dst,
diff --git a/llama.cpp b/llama.cpp
index 52f91ae29..81d998c18 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1402,13 +1402,13 @@ static bool llama_eval_internal(
             // SILU activation
             cur = ggml_silu(ctx0, cur);
 
+            cur = ggml_mul(ctx0, cur, tmp);
+
             // TODO: TMP !!!!
             if (il == 0) {
                 ggml_set_name(cur, "mtl-check");
             }
 
-            cur = ggml_mul(ctx0, cur, tmp);
-
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].w2,
                     cur);
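
Note: the body of kernel_mul_row falls outside the hunks above. From the dispatch condition in mtl.m (the row kernel is selected when ggml_nelements(src1) == ne10, i.e. src1 is a single row), a minimal sketch of the broadcast variant might look as follows. This is an assumption, not part of the patch: in particular, the constant int64_t & ne00 argument presumes the encoder binds src0's row length at buffer index 3, and that setBytes call is also outside the hunk.

    // sketch only: the actual kernel_mul_row body is not shown in the diff above;
    // assumes ne00 (the row length of src0) is bound by the encoder at buffer index 3
    kernel void kernel_mul_row(
            device const float * src0,
            device const float * src1,
            device       float * dst,
            constant   int64_t & ne00,
            uint tpig[[thread_position_in_grid]]) {
        // wrap the src1 index so the single row is broadcast across every row of src0
        dst[tpig] = src0[tpig] * src1[tpig % ne00];
    }

The tpig % ne00 wrap is also what the TODO in mtl.m refers to: folding broadcast support into kernel_mul itself would put a modulo (or a branch) on the hot elementwise path for the common non-broadcast case, which is presumably why a separate kernel is used for now.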