From 77d662faa5c3d9a8b00dfb52dde884603bf153a4 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 25 Jul 2023 20:36:32 +0200 Subject: [PATCH] llama.cpp : allocate graph in the context ggml-ci --- ggml.c | 2 +- llama.cpp | 28 ++++++++++++++-------------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/ggml.c b/ggml.c index d4abfcc5f..186b0fd60 100644 --- a/ggml.c +++ b/ggml.c @@ -4212,7 +4212,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) { } size_t ggml_tensor_overhead(void) { - return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE; // REVIEW: i don't think we need to 16 here because GGML_OBJECT_SIZE and GGML_TENSOR_SIZE are already aligned + return GGML_OBJECT_SIZE + GGML_TENSOR_SIZE; // REVIEW: i don't think we need to add 16 here because GGML_OBJECT_SIZE and GGML_TENSOR_SIZE are already aligned } bool ggml_is_transposed(const struct ggml_tensor * tensor) { diff --git a/llama.cpp b/llama.cpp index 30d4b0a6e..e649dce52 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1424,7 +1424,7 @@ static bool llama_eval_internal( struct ggml_context * ctx0 = ggml_init(params); - ggml_cgraph gf = {}; + ggml_cgraph * gf = ggml_new_graph(ctx0); // for big prompts, if BLAS is enabled, it is better to use only one thread // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance @@ -1541,8 +1541,8 @@ static bool llama_eval_internal( ggml_set_name(v, "v"); // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Kcur, k)); - ggml_build_forward_expand(&gf, ggml_cpy(ctx0, Vcur, v)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Kcur, k)); + ggml_build_forward_expand(gf, ggml_cpy(ctx0, Vcur, v)); } struct ggml_tensor * Q = @@ -1712,7 +1712,7 @@ static bool llama_eval_internal( //cur = ggml_soft_max_inplace(ctx0, cur); // run the computation - ggml_build_forward_expand(&gf, cur); + ggml_build_forward_expand(gf, cur); // fprintf(stderr, "graph build time: %.3f ms (%d nodes, %d leafs)\n", (ggml_time_us() - t_start_us)/1000.0, gf.n_nodes, gf.n_leafs); @@ -1723,10 +1723,10 @@ static bool llama_eval_internal( #ifdef GGML_USE_METAL if (lctx.ctx_metal && N == 1) { if (!ggml_metal_if_optimized(lctx.ctx_metal)) { - ggml_metal_graph_find_concurrency(lctx.ctx_metal,&gf); + ggml_metal_graph_find_concurrency(lctx.ctx_metal, gf); } ggml_metal_set_n_cb (lctx.ctx_metal, n_threads); - ggml_metal_graph_compute(lctx.ctx_metal, &gf); + ggml_metal_graph_compute(lctx.ctx_metal, gf); ggml_metal_get_tensor (lctx.ctx_metal, cur); } else { // IMPORTANT: @@ -1745,34 +1745,34 @@ static bool llama_eval_internal( ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v); } - ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); + ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); } #else - ggml_graph_compute_helper(lctx.work_buffer, &gf, n_threads); + ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads); #endif #if GGML_USE_MPI - ggml_mpi_graph_compute_post(lctx.ctx_mpi, &gf, n_layer); + ggml_mpi_graph_compute_post(lctx.ctx_mpi, gf, n_layer); #endif // update kv token count lctx.kv_self.n = n_past + N; - struct ggml_tensor * res = gf.nodes[gf.n_nodes - 1]; + struct ggml_tensor * res = gf->nodes[gf->n_nodes - 1]; if (cgraph_fname) { - ggml_graph_export(&gf, cgraph_fname); + ggml_graph_export(gf, cgraph_fname); } #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes) // requires GGML_PERF to be defined - ggml_graph_print(&gf); + ggml_graph_print(gf); #endif // plot the computation graph in dot format (for debugging purposes) //if (n_past%100 == 0) { - // ggml_graph_dump_dot(&gf, NULL, "llama.dot"); + // ggml_graph_dump_dot(gf, NULL, "llama.dot"); //} // extract logits @@ -3177,7 +3177,7 @@ struct llama_context * llama_new_context_with_model( ctx->embedding.resize(hparams.n_embd); } - ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type)); + ctx->buf_compute.resize(MEM_REQ_EVAL().at(ctx->model.type) + ggml_graph_overhead()); ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0(hparams.n_ctx).at(ctx->model.type)); ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1().at(ctx->model.type));