From 0f1c580860e2acbee7c095b113256f69e93869b5 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 1 Jun 2023 19:52:32 +0300
Subject: [PATCH] mtl : add scale kernel

---
 examples/mtl/mtl.m     | 27 +++++++++++++++++++++++++++
 examples/mtl/mtl.metal |  8 ++++++++
 llama.cpp              |  8 ++++----
 3 files changed, 39 insertions(+), 4 deletions(-)

diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index 1327de0b4..8f55f8467 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -27,6 +27,9 @@ struct ggml_mtl_context {
     id<MTLFunction>             function_mul;
     id<MTLComputePipelineState> pipeline_mul;
 
+    id<MTLFunction>             function_scale;
+    id<MTLComputePipelineState> pipeline_scale;
+
     id<MTLFunction>             function_relu;
     id<MTLComputePipelineState> pipeline_relu;
 
@@ -135,6 +138,10 @@ struct ggml_mtl_context * llama_mtl_init(
     ctx->pipeline_mul = [ctx->device newComputePipelineStateWithFunction:ctx->function_mul error:nil];
     fprintf(stderr, "%s: loaded kernel_mul: %p\n", __func__, (void *) ctx->pipeline_mul);
 
+    ctx->function_scale = [ctx->library newFunctionWithName:@"kernel_scale"];
+    ctx->pipeline_scale = [ctx->device newComputePipelineStateWithFunction:ctx->function_scale error:nil];
+    fprintf(stderr, "%s: loaded kernel_scale: %p\n", __func__, (void *) ctx->pipeline_scale);
+
     ctx->function_relu = [ctx->library newFunctionWithName:@"kernel_relu"];
     ctx->pipeline_relu = [ctx->device newComputePipelineStateWithFunction:ctx->function_relu error:nil];
     fprintf(stderr, "%s: loaded kernel_relu: %p\n", __func__, (void *) ctx->pipeline_relu);
@@ -310,6 +317,26 @@ int llama_mtl_eval(
 
                     const int64_t n = ggml_nelements(gf->nodes[i]);
 
+                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                } break;
+            case GGML_OP_SCALE:
+                {
+                    if (encoder == nil) {
+                        encoder = [command_buffer computeCommandEncoder];
+                    }
+
+                    id<MTLBuffer> id_src0 = llama_mtl_get_buffer(ctx, gf->nodes[i]->src0, &offs_src0);
+                    id<MTLBuffer> id_dst  = llama_mtl_get_buffer(ctx, gf->nodes[i],       &offs_dst);
+
+                    const float scale = *(const float *) gf->nodes[i]->src1->data;
+
+                    [encoder setComputePipelineState:ctx->pipeline_scale];
+                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
+                    [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
+
+                    const int64_t n = ggml_nelements(gf->nodes[i]);
+
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
             case GGML_OP_RELU:
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index 7e5c3aad4..b132be15e 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -53,6 +53,14 @@ kernel void kernel_mul(
     dst[tpig] = src0[tpig] * src1[tpig % ne00];
 }
 
+kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
 kernel void kernel_relu(
         device const float * src0,
         device       float * dst,
diff --git a/llama.cpp b/llama.cpp
index f6d93bd93..28d489016 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1324,10 +1324,6 @@ static bool llama_eval_internal(
         // K * Q
         struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
         ggml_set_name(KQ, "KQ");
-        // TODO: TMP !!!!
-        if (il == 0) {
-            ggml_set_name(KQ, "mtl-check");
-        }
 
         // KQ_scaled = KQ / sqrt(n_embd/n_head)
         struct ggml_tensor * KQ_scale = ggml_new_f32(ctx0, 1.0f/sqrtf(float(n_embd)/n_head));
@@ -1336,6 +1332,10 @@ static bool llama_eval_internal(
         // KQ_scaled shape [n_past + N, N, n_head, 1]
         struct ggml_tensor * KQ_scaled = ggml_scale_inplace(ctx0, KQ, KQ_scale);
         ggml_set_name(KQ_scaled, "KQ_scaled");
+        // TODO: TMP !!!!
+        if (il == 0) {
+            ggml_set_name(KQ_scaled, "mtl-check");
+        }
 
         // KQ_masked = mask_past(KQ_scaled)
         struct ggml_tensor * KQ_masked = ggml_diag_mask_inf_inplace(ctx0, KQ_scaled, n_past);