From d0f75193382cac7dbcb250edab8e7943f79d862d Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 24 Mar 2023 22:58:00 +0200
Subject: [PATCH] Fix KV cache size for F32

---
 llama.cpp | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 071b54e47..9d48ccd4c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -230,11 +230,16 @@ struct llama_context {
 
 static bool kv_cache_init(
         const struct llama_hparams & hparams,
-                      const size_t   mem_bytes,
              struct llama_kv_cache & cache,
                          ggml_type   wtype,
                                int   n_ctx) {
-    cache.buf.resize(mem_bytes);
+    const int n_embd  = hparams.n_embd;
+    const int n_layer = hparams.n_layer;
+
+    const int n_mem      = n_layer*n_ctx;
+    const int n_elements = n_embd*n_mem;
+
+    cache.buf.resize(2*n_elements*ggml_type_size(wtype) + 2u*MB);
 
     struct ggml_init_params params;
     params.mem_size = cache.buf.size();
@@ -247,12 +252,6 @@ static bool kv_cache_init(
         return false;
     }
 
-    const int n_embd  = hparams.n_embd;
-    const int n_layer = hparams.n_layer;
-
-    const int n_mem      = n_layer*n_ctx;
-    const int n_elements = n_embd*n_mem;
-
     cache.k = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
     cache.v = ggml_new_tensor_1d(cache.ctx, wtype, n_elements);
 
@@ -485,6 +484,8 @@ static bool llama_model_load(
 
     // print memory requirements
     {
+        const size_t scale = memory_type == GGML_TYPE_F32 ? 2 : 1;
+
         // this is the total memory required to run the inference
         const size_t mem_required =
             ctx_size +
@@ -494,7 +495,7 @@ static bool llama_model_load(
 
         // this is the memory required by one llama_state
         const size_t mem_required_state =
-            MEM_REQ_KV_SELF.at (model.type);
+            scale*MEM_REQ_KV_SELF.at(model.type);
 
         fprintf(stderr, "%s: mem required = %7.2f MB (+ %7.2f MB per state)\n", __func__,
                 mem_required / 1024.0 / 1024.0, mem_required_state / 1024.0 / 1024.0);
@@ -1634,7 +1635,7 @@ struct llama_context * llama_init_from_file(
 
     // reserve memory for context buffers
     {
-        if (!kv_cache_init(ctx->model.hparams, MEM_REQ_KV_SELF.at(ctx->model.type), ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
+        if (!kv_cache_init(ctx->model.hparams, ctx->model.kv_self, memory_type, ctx->model.hparams.n_ctx)) {
             fprintf(stderr, "%s: kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
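
For reference, below is a standalone sketch (not part of the patch) of the size arithmetic the reworked kv_cache_init now performs. The 7B hyperparameters (n_embd = 4096, n_layer = 32) and the default n_ctx = 512 are assumed example values, used only to illustrate why an F32 cache needs twice the memory of an F16 one.

    // sketch only: mirrors the 2*n_elements*type_size computation in kv_cache_init
    #include <cstdio>
    #include <cstdint>
    #include <cstddef>

    int main() {
        const int n_embd  = 4096; // assumed: LLaMA 7B embedding size
        const int n_layer = 32;   // assumed: LLaMA 7B layer count
        const int n_ctx   = 512;  // assumed: default context length

        const size_t n_mem      = (size_t) n_layer*n_ctx;
        const size_t n_elements = (size_t) n_embd*n_mem;

        // one K tensor and one V tensor, each holding n_elements entries
        const size_t kv_f16 = 2*n_elements*sizeof(uint16_t); // stand-in for F16 element size
        const size_t kv_f32 = 2*n_elements*sizeof(float);    // F32 elements are twice as wide

        printf("KV cache (F16): %8.2f MB\n", kv_f16 / 1024.0 / 1024.0);
        printf("KV cache (F32): %8.2f MB\n", kv_f32 / 1024.0 / 1024.0);

        return 0;
    }

With these values the F16 cache comes to 256 MB and the F32 cache to 512 MB, which is why the per-state memory report is scaled by 2 when memory_type == GGML_TYPE_F32.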