From a0cc3de59ad9026079b7ab6d58da1c3b0cdfdd55 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 1 Jun 2023 21:30:33 +0300
Subject: [PATCH] mtl : add f32 -> f32 cpy kernel

---
 examples/mtl/mtl.m     |  8 ++++++++
 examples/mtl/mtl.metal | 42 ++++++++++++++++++++++++++++++++++++++++++
 llama.cpp              | 10 +++++-----
 3 files changed, 55 insertions(+), 5 deletions(-)

diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index 24f9479ce..c617f4401 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -53,6 +53,9 @@ struct ggml_mtl_context {
 
     id<MTLFunction>             function_cpy_f32_f16;
     id<MTLComputePipelineState> pipeline_cpy_f32_f16;
+
+    id<MTLFunction>             function_cpy_f32_f32;
+    id<MTLComputePipelineState> pipeline_cpy_f32_f32;
 };
 
 // MSL code
@@ -176,6 +179,10 @@ struct ggml_mtl_context * llama_mtl_init(
         ctx->function_cpy_f32_f16 = [ctx->library newFunctionWithName:@"kernel_cpy_f32_f16"];
         ctx->pipeline_cpy_f32_f16 = [ctx->device newComputePipelineStateWithFunction:ctx->function_cpy_f32_f16 error:nil];
         fprintf(stderr, "%s: loaded kernel_cpy_f32_f16: %p\n", __func__, (void *) ctx->pipeline_cpy_f32_f16);
+
+        ctx->function_cpy_f32_f32 = [ctx->library newFunctionWithName:@"kernel_cpy_f32_f32"];
+        ctx->pipeline_cpy_f32_f32 = [ctx->device newComputePipelineStateWithFunction:ctx->function_cpy_f32_f32 error:nil];
+        fprintf(stderr, "%s: loaded kernel_cpy_f32_f32: %p\n", __func__, (void *) ctx->pipeline_cpy_f32_f32);
     }
 
     // MTLBuffer approach
@@ -669,6 +676,7 @@ int llama_mtl_eval(
                 {
                     switch (dstt) {
                         case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
+                        case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
                         default: GGML_ASSERT(false && "not implemented");
                     };
                 } break;
diff --git a/examples/mtl/mtl.metal b/examples/mtl/mtl.metal
index 32e850297..172a0fa7e 100644
--- a/examples/mtl/mtl.metal
+++ b/examples/mtl/mtl.metal
@@ -340,3 +340,45 @@ kernel void kernel_cpy_f32_f16(
         dst_data[i00] = src[0];
     }
 }
+
+kernel void kernel_cpy_f32_f32(
+        device const float * src0,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant  uint64_t & nb0,
+        constant  uint64_t & nb1,
+        constant  uint64_t & nb2,
+        constant  uint64_t & nb3,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+    const int64_t i3 = n / (ne2*ne1*ne0);
+    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+    const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+    device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+        device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+
+        dst_data[i00] = src[0];
+    }
+}
diff --git a/llama.cpp b/llama.cpp
index 2cf5a36fc..40292305e 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1350,11 +1350,6 @@ static bool llama_eval_internal(
                         il*n_ctx*ggml_element_size(kv_self.v)*n_embd);
             ggml_set_name(V, "V");
 
-            // TODO: TMP !!!!
-            if (il == 0) {
-                ggml_set_name(V, "mtl-check");
-            }
-
 #if 1
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V, KQ_soft_max);
             ggml_set_name(KQV, "KQV");
@@ -1376,6 +1371,11 @@ static bool llama_eval_internal(
                     ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
             ggml_set_name(cur, "KQV_merged_contiguous");
 
+            // TODO: TMP !!!!
+            if (il == 0) {
+                ggml_set_name(cur, "mtl-check");
+            }
+
             // projection (no bias)
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].wo,
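
Note on kernel_cpy_f32_f32 (illustration, not part of the patch): the kernel
flattens each source row (i01, i02, i03) to a linear element index n, then
re-decomposes n into destination coordinates (i0..i3) using the destination
shape, so src0 and dst may have different shapes and strides as long as the
total element counts match. Below is a minimal CPU sketch in C of the same
index arithmetic; the name cpy_f32_f32_ref is hypothetical, and like the
kernel it assumes each destination row is contiguous in floats.

#include <stdint.h>

// Reference loop nest: one iteration of the outer three loops corresponds to
// one Metal threadgroup (tgpig), and the inner i00 loop is what the kernel
// splits across threads (i00 = tpitg.x; i00 += ntg.x).
static void cpy_f32_f32_ref(
        const char * src0, char * dst,
        int64_t  ne00, int64_t  ne01, int64_t  ne02, int64_t  ne03,
        uint64_t nb00, uint64_t nb01, uint64_t nb02, uint64_t nb03,
        int64_t  ne0,  int64_t  ne1,  int64_t  ne2,
        uint64_t nb0,  uint64_t nb1,  uint64_t nb2,  uint64_t nb3) {
    for (int64_t i03 = 0; i03 < ne03; i03++) {
    for (int64_t i02 = 0; i02 < ne02; i02++) {
    for (int64_t i01 = 0; i01 < ne01; i01++) {
        // linear index of the first element of this source row
        const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;

        // decompose n into destination coordinates (ne3 is implied by the
        // total element count, so it never enters the arithmetic)
        const int64_t i3 = n / (ne2*ne1*ne0);
        const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
        const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
        const int64_t i0 =  n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0;

        // byte strides (nb*) locate the row in both tensors, exactly as in
        // the Metal kernel's (device char *) pointer arithmetic
        float * dst_data = (float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);

        for (int64_t i00 = 0; i00 < ne00; i00++) {
            const float * src = (const float *)(src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
            dst_data[i00] = src[0];
        }
    }
    }
    }
}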