From 2f4e9d19cce4ca9dc0a37d9734df17fe8d03dd49 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 2 Jun 2023 21:52:11 +0300
Subject: [PATCH] mtl : plug Metal inference into llama.cpp (very quick-n-dirty)

---
 CMakeLists.txt      |  15 ++++-
 examples/common.cpp |   4 ++
 examples/common.h   |   1 +
 examples/mtl/mtl.h  |   2 +
 examples/mtl/mtl.m  | 151 ++++++++++++++++++++++++++++++--------------
 llama.cpp           |  73 +++++++++++++++++----
 llama.h             |   1 +
 7 files changed, 186 insertions(+), 61 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 21f4ec9dd..bc23c2c5b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -384,11 +384,22 @@ endif()
 add_library(llama
             llama.cpp
             llama.h
-            llama-util.h)
+            llama-util.h
+            examples/mtl/mtl.h # TODO: METAL TMP
+            examples/mtl/mtl.m # TODO: METAL TMP
+            )
 
 target_include_directories(llama PUBLIC .)
 target_compile_features(llama PUBLIC cxx_std_11) # don't bump
-target_link_libraries(llama PRIVATE ggml ${LLAMA_EXTRA_LIBS})
+target_link_libraries(llama PRIVATE
+    ggml
+    ${LLAMA_EXTRA_LIBS}
+    ${FOUNDATION_LIBRARY}         # TODO: METAL TMP
+    ${METAL_FRAMEWORK}            # TODO: METAL TMP
+    ${METALKIT_FRAMEWORK}         # TODO: METAL TMP
+    ${METALPERFORMANCE_FRAMEWORK} # TODO: METAL TMP
+    )
+target_compile_definitions(llama PRIVATE LLAMA_MTL_NDEBUG) # TODO: METAL TMP
 
 if (BUILD_SHARED_LIBS)
     set_target_properties(llama PROPERTIES POSITION_INDEPENDENT_CODE ON)
diff --git a/examples/common.cpp b/examples/common.cpp
index b5810f28f..53e9200fa 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -301,6 +301,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             params.mem_test = true;
         } else if (arg == "--export") {
             params.export_cgraph = true;
+        } else if (arg == "--import") {
+            params.import_cgraph = true;
         } else if (arg == "--verbose-prompt") {
             params.verbose_prompt = true;
         } else if (arg == "-r" || arg == "--reverse-prompt") {
@@ -441,6 +443,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 #endif
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --export              export the computation graph to 'llama.ggml'\n");
+    fprintf(stderr, "  --import              import a computation graph from 'llama.ggml'\n");
     fprintf(stderr, "  --verbose-prompt      print prompt before generation\n");
     fprintf(stderr, "  --lora FNAME          apply LoRA adapter (implies --no-mmap)\n");
     fprintf(stderr, "  --lora-base FNAME     optional model to use as a base for the layers modified by the LoRA adapter\n");
@@ -490,6 +493,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
     lparams.use_mlock   = params.use_mlock;
     lparams.logits_all  = params.perplexity;
     lparams.embedding   = params.embedding;
+    lparams.cgraph      = params.import_cgraph;
 
     llama_context * lctx = llama_init_from_file(params.model.c_str(), lparams);
 
diff --git a/examples/common.h b/examples/common.h
index 66bdeb5e9..c7d4d6e0e 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -72,6 +72,7 @@ struct gpt_params {
     bool use_mlock         = false; // use mlock to keep model in memory
     bool mem_test          = false; // compute maximum memory usage
     bool export_cgraph     = false; // export the computation graph
+    bool import_cgraph     = false; // import a computation graph
     bool verbose_prompt    = false; // print prompt tokens before generation
 };
 
diff --git a/examples/mtl/mtl.h b/examples/mtl/mtl.h
index a6a336eaa..f381756d4 100644
--- a/examples/mtl/mtl.h
+++ b/examples/mtl/mtl.h
@@ -25,6 +25,8 @@ int llama_mtl_eval(
         int   n_tokens,
         int   n_past);
 
+float * llama_mtl_get_logits(struct ggml_mtl_context * ctx);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/examples/mtl/mtl.m b/examples/mtl/mtl.m
index e4839626e..4ac5dac20 100644
--- a/examples/mtl/mtl.m
+++ b/examples/mtl/mtl.m
@@ -6,11 +6,19 @@
 #import <Metal/Metal.h>
 #import <MetalPerformanceShaders/MetalPerformanceShaders.h>
 
+#ifdef LLAMA_MTL_NDEBUG
+#define mtl_printf(...)
+#else
+#define mtl_printf(...) fprintf(stderr, __VA_ARGS__)
+#endif
+
 struct ggml_mtl_context {
     struct ggml_context * ctx_data;
     struct ggml_context * ctx_eval;
     struct ggml_context * ctx_work;
 
+    float * logits;
+
     id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
     id<MTLLibrary>      library;
@@ -274,7 +282,44 @@ int llama_mtl_eval(
         const int * tokens,
               int   n_tokens,
               int   n_past) {
-    fprintf(stderr, "%s: evaluating, n_tokens = %d, n_past = %d\n", __func__, n_tokens, n_past);
+    mtl_printf("%s: evaluating, n_tokens = %d, n_past = %d\n", __func__, n_tokens, n_past);
+
+    // adjust dynamic shapes
+    // TODO: wrong ...
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "embd");
+    //    t->ne[0] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Qpre");
+    //    t->src0->ne[2] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Kpre");
+    //    t->src0->ne[2] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Vcur");
+    //    t->ne[0] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * k = ggml_graph_get_tensor(gf, "k");
+    //    struct ggml_tensor * v = ggml_graph_get_tensor(gf, "v");
+    //    k->ne[0] = n_tokens*v->ne[1];
+    //    v->ne[0] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "Q");
+    //    t->ne[1] = n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "K");
+    //    t->ne[1] = n_past + n_tokens;
+    //}
+    //{
+    //    struct ggml_tensor * t = ggml_graph_get_tensor(gf, "KQV_merged_contiguous");
+    //    t->src1->ne[1] = n_tokens;
+    //}
 
     struct ggml_tensor * input = ggml_graph_get_tensor(gf, "embd");
     memcpy(input->data, tokens, n_tokens * sizeof(int));
@@ -296,7 +341,7 @@ int llama_mtl_eval(
     }
 
     for (int i = 0; i < gf->n_nodes; ++i) {
-        //fprintf(stderr, "%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+        //mtl_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
 
         struct ggml_tensor * src0 = gf->nodes[i]->src0;
         struct ggml_tensor * src1 = gf->nodes[i]->src1;
@@ -340,7 +385,21 @@ int llama_mtl_eval(
         id<MTLBuffer> id_src1 = src1 ? llama_mtl_get_buffer(ctx, src1, &offs_src1) : nil;
         id<MTLBuffer> id_dst  = dst  ? llama_mtl_get_buffer(ctx, dst,  &offs_dst)  : nil;
 
-        switch (gf->nodes[i]->op) {
+        //mtl_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
+        //if (src0) {
+        //    mtl_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
+        //            ggml_is_contiguous(src0), src0->name);
+        //}
+        //if (src1) {
+        //    mtl_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
+        //            ggml_is_contiguous(src1), src1->name);
+        //}
+        //if (dst) {
+        //    mtl_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
+        //            dst->name);
+        //}
+
+        switch (dst->op) {
             case GGML_OP_RESHAPE:
             case GGML_OP_VIEW:
             case GGML_OP_TRANSPOSE:
@@ -359,7 +418,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -369,7 +428,7 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }
 
-                    if (ggml_nelements(gf->nodes[i]->src1) == ne10) {
+                    if (ggml_nelements(src1) == ne10) {
                         // src1 is a row
                         [encoder setComputePipelineState:ctx->pipeline_mul_row];
                     } else {
@@ -380,7 +439,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                     [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -390,14 +449,14 @@ int llama_mtl_eval(
                        encoder = [command_buffer computeCommandEncoder];
                     }
 
-                    const float scale = *(const float *) gf->nodes[i]->src1->data;
+                    const float scale = *(const float *) src1->data;
 
                     [encoder setComputePipelineState:ctx->pipeline_scale];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                     [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -411,7 +470,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -425,7 +484,7 @@ int llama_mtl_eval(
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]);
+                    const int64_t n = ggml_nelements(dst);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -464,16 +523,11 @@ int llama_mtl_eval(
                 } break;
             case GGML_OP_MUL_MAT:
                 {
-                    //fprintf(stderr, "mul_mat: src0 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src0t), ne00, ne01, ne02, ggml_is_contiguous(gf->nodes[i]->src0));
-                    //fprintf(stderr, "mul_mat: src1 - %s[%lld, %lld, %lld], %d\n", ggml_type_name(src1t), ne10, ne11, ne12, ggml_is_contiguous(gf->nodes[i]->src1));
-                    //fprintf(stderr, "mul_mat: dst - %s[%lld, %lld, %lld]\n", ggml_type_name(dstt), ne0, ne1, ne2);
-                    //fprintf(stderr, "mul_mat: %s * %s -> %s\n", ggml_type_name(src0t), ggml_type_name(src1t), ggml_type_name(dstt));
-
                     GGML_ASSERT(ne00 == ne10);
                     GGML_ASSERT(ne02 == ne12);
 
-                    if (ggml_is_contiguous(gf->nodes[i]->src0) &&
-                        ggml_is_contiguous(gf->nodes[i]->src1) &&
+                    if (ggml_is_contiguous(src0) &&
+                        ggml_is_contiguous(src1) &&
                         (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) &&
                         ne11 > 1) {
                         if (encoder != nil) {
@@ -486,13 +540,13 @@ int llama_mtl_eval(
 
                         // for F32 x F32 we use MPS
                         MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:gf->nodes[i]->src0->nb[1] dataType:src0dt];
+                            matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
 
                         MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:gf->nodes[i]->src1->nb[1] dataType:src1dt];
+                            matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
 
                         MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
-                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:gf->nodes[i]->nb[1] dataType:MPSDataTypeFloat32];
+                            matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
 
                         MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
                             initWithDevice:ctx->device transposeLeft:false transposeRight:true
@@ -573,22 +627,22 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }
 
-                    switch (gf->nodes[i]->src0->type) {
+                    switch (src0->type) {
                         case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                         default: {
                                 // not implemented
-                                fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op), ggml_type_name(gf->nodes[i]->src0->type));
+                                fprintf(stderr, "%s: node %3d, op = %8s, type = %8s not implemented\n", __func__, i, ggml_op_name(dst->op), ggml_type_name(src0->type));
                             }
                     }
 
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                    [encoder setBytes:&(gf->nodes[i]->src0->ne[0]) length:sizeof( int64_t) atIndex:3];
-                    [encoder setBytes:&(gf->nodes[i]->src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
-                    [encoder setBytes:&(gf->nodes[i]->nb[1])       length:sizeof(uint64_t) atIndex:5];
+                    [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
+                    [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
+                    [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
 
-                    const int64_t n = ggml_nelements(gf->nodes[i]->src1);
+                    const int64_t n = ggml_nelements(src1);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                 } break;
@@ -610,7 +664,7 @@ int llama_mtl_eval(
                     [encoder setBytes:&eps length:sizeof( float) atIndex:4];
                     [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
-                    const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);
+                    const int64_t nrows = ggml_nrows(src0);
 
                     [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
@@ -620,12 +674,12 @@ int llama_mtl_eval(
                         encoder = [command_buffer computeCommandEncoder];
                     }
 
-                    const int n_dims = ((int32_t *) gf->nodes[i]->src1->data)[1];
-                    const int mode   = ((int32_t *) gf->nodes[i]->src1->data)[2];
+                    const int n_dims = ((int32_t *) src1->data)[1];
+                    const int mode   = ((int32_t *) src1->data)[2];
 
-                    //fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
-                    //fprintf(stderr, "rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
-                    //fprintf(stderr, "rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
+                    //mtl_printf("rope: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    //mtl_printf("rope: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    //mtl_printf("rope: n_past = %d, n_dims = %d, mode = %d\n", n_past, n_dims, mode);
 
                     [encoder setComputePipelineState:ctx->pipeline_rope];
                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -660,11 +714,11 @@ int llama_mtl_eval(
 
                     const int nth = 32;
 
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
-                    //fprintf(stderr, "cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
-                    //fprintf(stderr, "cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", ne00, ne01, ne02, ne03);
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", nb00, nb01, nb02, nb03);
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", ne0, ne1, ne2, ne3);
+                    //mtl_printf("cpy: %lld x %lld x %lld x %lld\n", nb0, nb1, nb2, nb3);
+                    //mtl_printf("cpy: %s -> %s\n", ggml_type_name(src0t), ggml_type_name(dstt));
 
                     switch (src0t) {
                         case GGML_TYPE_F32:
@@ -700,7 +754,7 @@ int llama_mtl_eval(
                     [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                 } break;
             default:
-                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
+                fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
                 GGML_ASSERT(false);
                 return -1;
         }
@@ -708,7 +762,7 @@ int llama_mtl_eval(
 
     // extract results from the GPU
     {
-        fprintf(stderr, "%s: extract results from the GPU\n", __func__);
+        mtl_printf("%s: extract results from the GPU\n", __func__);
         if (encoder != nil) {
             [encoder endEncoding];
             encoder = nil;
@@ -730,18 +784,19 @@ int llama_mtl_eval(
 
     {
         const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
-        printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
+        mtl_printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
     }
 
-    // TODO
-    const float * logits = ctx->out.contents;
+    ctx->logits = ctx->out.contents;
+
+    const float * logits = ctx->logits;
 
 #if 1
-    printf("logits: ");
+    mtl_printf("logits: ");
     for (int i = 0; i < 100; i++) {
-        printf("%8.4f ", logits[i]);
+        mtl_printf("%8.4f ", logits[i]);
     }
-    printf("\n");
+    mtl_printf("\n");
     double sum = 0.0;
     int imax = 0;
     double vmax = -INFINITY;
@@ -752,7 +807,7 @@ int llama_mtl_eval(
             imax = i;
         }
     }
-    printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
+    mtl_printf("sum: %f, imax = %d, vmax = %f\n", sum, imax, vmax);
 #endif
 
     //{
@@ -801,3 +856,7 @@ int llama_mtl_eval(
 
     return 0;
 }
+
+float * llama_mtl_get_logits(struct ggml_mtl_context * ctx) {
+    return ctx->logits;
+}
diff --git a/llama.cpp b/llama.cpp
index 9a8bf9df7..93ca233a9 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -9,6 +9,9 @@
 #include "llama-util.h"
 #include "llama.h"
 
+// METAL
+#include "examples/mtl/mtl.h"
+
 #include "ggml.h"
 #ifdef GGML_USE_CUBLAS
 #include "ggml-cuda.h"
@@ -238,6 +241,10 @@ struct llama_context {
     llama_ctx_buffer buf_compute;
     llama_ctx_buffer buf_scratch[LLAMA_MAX_SCRATCH_BUFFERS];
 
+    // METAL
+    ggml_mtl_context * mtl_ctx = NULL;
+    ggml_cgraph mtl_gf;
+
     int    buf_last = 0;
     size_t buf_max_size[LLAMA_MAX_SCRATCH_BUFFERS] = { 0 };
 
@@ -836,6 +843,7 @@ struct llama_context_params llama_context_default_params() {
         /*.use_mmap                    =*/ true,
         /*.use_mlock                   =*/ false,
         /*.embedding                   =*/ false,
+        /*.cgraph                      =*/ false,
         /*.progress_callback           =*/ nullptr,
         /*.progress_callback_user_data =*/ nullptr,
     };
@@ -1270,8 +1278,14 @@ static bool llama_eval_internal(
             //struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, x, n_embd/n_head, n_head, N), n_past, n_rot, 0);
 
             // compute Q and K and RoPE them
-            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
-            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0);
+
+            struct ggml_tensor * Qpre = ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wq, cur), n_embd/n_head, n_head, N);
+            struct ggml_tensor * Kpre = ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wk, cur), n_embd/n_head, n_head, N);
+            ggml_set_name(Qpre, "Qpre");
+            ggml_set_name(Kpre, "Kpre");
+
+            struct ggml_tensor * Qcur = ggml_rope_inplace(ctx0, Qpre, n_past, n_rot, 0);
+            struct ggml_tensor * Kcur = ggml_rope_inplace(ctx0, Kpre, n_past, n_rot, 0);
 
             ggml_set_name(Qcur, "Qcur");
             ggml_set_name(Kcur, "Kcur");
@@ -1279,22 +1293,19 @@ static bool llama_eval_internal(
             {
                 // compute the transposed [N, n_embd] V matrix
                 struct ggml_tensor * Vcur = ggml_transpose(ctx0, ggml_reshape_2d(ctx0, ggml_mul_mat(ctx0, model.layers[il].wv, cur), n_embd, N));
+                ggml_set_name(Vcur, "Vcur");
 
                 struct ggml_tensor * k = ggml_view_1d(ctx0, kv_self.k, N*n_embd, (ggml_element_size(kv_self.k)*n_embd)*(il*n_ctx + n_past));
                 struct ggml_tensor * v = ggml_view_2d(ctx0, kv_self.v, N, n_embd,
                         (   n_ctx)*ggml_element_size(kv_self.v),
                         (il*n_ctx)*ggml_element_size(kv_self.v)*n_embd + n_past*ggml_element_size(kv_self.v));
 
-                //struct ggml_tensor * t = ggml_cpy(ctx0, Vcur, v);
-                //// TODO: TMP !!!!
-                //if (il == 0) {
-                //    ggml_set_name(t, "mtl-check");
-                //}
+                ggml_set_name(k, "k");
+                ggml_set_name(v, "v");
 
                 // important: storing RoPE-ed version of K in the KV cache!
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k));
                 ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v));
-                //ggml_build_forward_expand(&gf, t);
             }
 
             struct ggml_tensor * Q =
@@ -2391,9 +2402,25 @@ struct llama_context * llama_init_from_file(
 
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
+    // METAL
+    if (params.cgraph) {
+        params.vocab_only = true;
+
+        // load the compute graph
+        struct ggml_context * ctx_data = NULL;
+        struct ggml_context * ctx_eval = NULL;
+
+        struct ggml_cgraph gf = ggml_graph_import("llama.ggml", &ctx_data, &ctx_eval);
+        gf.n_threads = 1;
+
+        // this allocates all Metal resources and memory buffers
+        ctx->mtl_ctx = llama_mtl_init(ctx_data, ctx_eval, NULL, &gf);
+        ctx->mtl_gf = gf;
+    }
+
     if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
-                          params.use_mmap, params.use_mlock, params.vocab_only,
-                          params.progress_callback, params.progress_callback_user_data)) {
+            params.use_mmap, params.use_mlock, params.vocab_only,
+            params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
         return nullptr;
@@ -2411,7 +2438,11 @@ struct llama_context * llama_init_from_file(
         const size_t memory_size = ggml_nbytes(ctx->model.kv_self.k) + ggml_nbytes(ctx->model.kv_self.v);
         fprintf(stderr, "%s: kv self size  = %7.2f MB\n", __func__, memory_size / 1024.0 / 1024.0);
     }
+    }
 
+    // METAL
+    // TODO: changed the behavior here for vocab_only -- reconsider implications later
+    {
         const auto & hparams = ctx->model.hparams;
 
         // resized during inference
@@ -3046,9 +3077,25 @@ int llama_eval(
                          int   n_tokens,
                          int   n_past,
                          int   n_threads) {
-    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
-        fprintf(stderr, "%s: failed to eval\n", __func__);
-        return 1;
+    // METAL
+    if (ctx->mtl_ctx) {
+        llama_mtl_eval(ctx->mtl_ctx, &ctx->mtl_gf, tokens, n_tokens, n_past);
+
+        const float * logits = llama_mtl_get_logits(ctx->mtl_ctx);
+
+        // extract logits
+        {
+            const int n_vocab = ctx->model.hparams.n_vocab;
+            auto & logits_out = ctx->logits;
+
+            logits_out.resize(n_vocab);
+            memcpy(logits_out.data(), logits, sizeof(float)*n_vocab);
+        }
+    } else {
+        if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, nullptr)) {
+            fprintf(stderr, "%s: failed to eval\n", __func__);
+            return 1;
+        }
     }
 
     // get a more accurate load time, upon first eval
diff --git a/llama.h b/llama.h
index 3ba0775bd..faaca2637 100644
--- a/llama.h
+++ b/llama.h
@@ -75,6 +75,7 @@ extern "C" {
         bool use_mmap;   // use mmap if possible
         bool use_mlock;  // force system to keep model in RAM
         bool embedding;  // embedding mode only
+        bool cgraph;     // try to load computation graph from "llama.ggml" (METAL)
 
         // called with a progress value between 0 and 1, pass NULL to disable
         llama_progress_callback progress_callback;