diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index ade0719ba..f13c00776 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -24,6 +24,9 @@ struct ggml_mtl_context {
     id<MTLFunction>             function_add;
     id<MTLComputePipelineState> pipeline_add;
 
+    id<MTLFunction>             function_mul;
+    id<MTLComputePipelineState> pipeline_mul;
+
     id<MTLFunction>             function_relu;
     id<MTLComputePipelineState> pipeline_relu;
 
@@ -119,6 +122,10 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->pipeline_add = [ctx->device newComputePipelineStateWithFunction:ctx->function_add error:nil];
         fprintf(stderr, "%s: loaded kernel_add: %p\n", __func__, (void *) ctx->pipeline_add);
 
+        ctx->function_mul = [ctx->library newFunctionWithName:@"kernel_mul"];
+        ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil];
+        fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul);
+
         ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
         ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
         fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu);
@@ -253,6 +260,28 @@ int llama_mtl_eval(
                     const int64_t n = ggml_nelements(gf->nodes[i]);
 
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_MUL:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_src1 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src1, &offs_src1);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const int64_t ne00 = gf->nodes[i]->src0->ne[0];
+
+                    [encoder setComputePipelineState:ctx->pipeline_mul];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                    [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+
+                    const int64_t n = ggml_nelements(gf->nodes[i]);
+
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_RELU:
@@ -373,7 +402,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                     [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
                     [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                    [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
+                    [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
 
                     const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
 
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index 6a736446b..78dfbe011 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -42,6 +42,17 @@ kernel void kernel_add(
     dst[gid] = src0[gid] + src1[gid];
 }
 
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_mul(
+        device const float * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        uint gid[[thread_position_in_grid]]) {
+    dst[gid] = src0[gid] * src1[gid % ne00];
+}
+
 kernel void kernel_relu(
         device const float * src0,
         device       float * dst,
diff --git a/llama.cpp b/llama.cpp
index 3ee170e4c..3ddfeff01 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1263,13 +1263,13 @@ static bool llama_eval_internal(
         // norm
         {
             cur = ggml_rms_norm(ctx0, inpL);
+
+            // cur = cur*attention_norm(broadcasted)
+            cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
 
             // TODO: TMP !!!!
             if (il == 0) {
                 ggml_set_name(cur, "mtl-check");
            }
-
-            // cur = cur*attention_norm(broadcasted)
-            cur = ggml_mul(ctx0, cur, model.layers[il].attention_norm);
         }
 
         // self-attention
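
Note on the new kernel: kernel_mul assumes src1 is a single row of ne00 elements and broadcasts it across src0, i.e. dst[i] = src0[i] * src1[i % ne00]. A minimal CPU reference for this row-broadcast multiply, useful for sanity-checking the Metal output against ggml's result (illustrative sketch only, not part of the patch; the helper name is hypothetical):

    #include <cstdint>

    // Reference implementation mirroring kernel_mul's indexing:
    // every consecutive row of src0 (length ne00) is multiplied element-wise by src1.
    static void mul_broadcast_row_ref(const float * src0, const float * src1,
                                      float * dst, int64_t n, int64_t ne00) {
        for (int64_t i = 0; i < n; ++i) {
            dst[i] = src0[i] * src1[i % ne00];
        }
    }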