diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index c617f4401..7ad1722c0 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -30,6 +30,9 @@ struct ggml_mtl_context {
     id<MTLFunction>             function_scale;
     id<MTLComputePipelineState> pipeline_scale;
 
+    id<MTLFunction>             function_silu;
+    id<MTLComputePipelineState> pipeline_silu;
+
     id<MTLFunction>             function_relu;
     id<MTLComputePipelineState> pipeline_relu;
 
@@ -148,6 +151,10 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->pipeline_scale = [ctx->device newComputePipelineStateWithFunction:ctx->function_scale error:nil];
         fprintf(stderr, "%s: loaded kernel_scale: %p\n", __func__, (void *) ctx->pipeline_scale);
 
+        ctx->function_silu = [ctx->library newFunctionWithName:@"kernel_silu"];
+        ctx->pipeline_silu = [ctx->device newComputePipelineStateWithFunction:ctx->function_silu error:nil];
+        fprintf(stderr, "%s: loaded kernel_silu: %p\n", __func__, (void *) ctx->pipeline_silu);
+
         ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
         ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
         fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu);
@@ -357,6 +364,23 @@ int llama_mtl_eval(
 
                     const int64_t n = ggml_nelements(gf->nodes[i]);
 
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_SILU:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    [encoder setComputePipelineState:ctx->pipeline_silu];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+
+                    const int64_t n = ggml_nelements(gf->nodes[i]);
+
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_RELU:
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index 172a0fa7e..2c6386990 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -58,6 +58,14 @@ kernel void kernel_scale(
     dst[tpig] = src0[tpig] * scale;
 }
 
+kernel void kernel_silu(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    float x = src0[tpig];
+    dst[tpig] = x / (1.0f + exp(-x));
+}
+
 kernel void kernel_relu(
         device const float * src0,
         device       float * dst,
diff --git a/llama.cpp b/llama.cpp
index 40292305e..52f91ae29 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1371,11 +1371,6 @@ static bool llama_eval_internal(
                 ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
         ggml_set_name(cur, "KQV_merged_contiguous");
 
-        // TODO: TMP !!!!
-        if (il == 0) {
-            ggml_set_name(cur, "mtl-check");
-        }
-
         // projection (no bias)
         cur = ggml_mul_mat(ctx0,
                 model.layers[il].wo,
@@ -1407,6 +1402,11 @@ static bool llama_eval_internal(
             // SILU activation
            cur = ggml_silu(ctx0, cur);
 
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(cur, "mtl-check");
+            }
+
             cur = ggml_mul(ctx0, cur, tmp);
 
             cur = ggml_mul_mat(ctx0,
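
Note on the new kernel: kernel_silu applies the SiLU (sigmoid-weighted linear unit) activation element-wise, dst[i] = x / (1 + e^(-x)) = x * sigmoid(x), which is the same operation ggml_silu performs on the CPU path; it is dispatched one element per threadgroup, mirroring the existing kernel_relu launch. Below is a minimal CPU-side reference sketch for spot-checking values read back from the Metal output buffer. The silu_ref helper and the sample inputs are hypothetical and not part of this patch; they only illustrate the expected numbers.

    #include <cmath>
    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // CPU reference for SiLU: dst[i] = x / (1 + e^(-x)) = x * sigmoid(x)
    static void silu_ref(const float * src, float * dst, int64_t n) {
        for (int64_t i = 0; i < n; ++i) {
            const float x = src[i];
            dst[i] = x / (1.0f + std::exp(-x));
        }
    }

    int main() {
        const std::vector<float> src = { -2.0f, -0.5f, 0.0f, 0.5f, 2.0f };
        std::vector<float> dst(src.size());

        silu_ref(src.data(), dst.data(), (int64_t) src.size());

        // Print reference values to compare against the kernel_silu output
        // read back from the Metal buffer (e.g. the "mtl-check" tensor).
        for (size_t i = 0; i < src.size(); ++i) {
            printf("silu(%6.3f) = %8.5f\n", src[i], dst[i]);
        }

        return 0;
    }

Since the operation is purely element-wise, comparing a handful of elements of the layer-0 "mtl-check" tensor (which this patch moves to the ggml_silu output) against such a reference is enough to confirm the kernel wiring.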