From ea60d2193a2346d24939d433dd7da54bfb0afd6e Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 24 Mar 2023 21:41:47 +0200
Subject: [PATCH] Simpler scratch buffer usage

---
 llama.cpp | 104 ++++++++++------------------------------------------
 1 file changed, 18 insertions(+), 86 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 7c5b09a5d..b5684d6fa 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -48,45 +48,17 @@ static const size_t MB = 1024*1024;
 // needs modifications in ggml
 
 static const std::map<e_model, size_t> MEM_REQ_SCRATCH0 = {
-    { MODEL_7B,    128ull*MB },
-    { MODEL_13B,   128ull*MB },
-    { MODEL_30B,   128ull*MB },
-    { MODEL_65B,   128ull*MB },
+    { MODEL_7B,    512ull*MB },
+    { MODEL_13B,   512ull*MB },
+    { MODEL_30B,   512ull*MB },
+    { MODEL_65B,   512ull*MB },
 };
 
 static const std::map<e_model, size_t> MEM_REQ_SCRATCH1 = {
-    { MODEL_7B,    128ull*MB },
-    { MODEL_13B,   128ull*MB },
-    { MODEL_30B,   128ull*MB },
-    { MODEL_65B,   128ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH2 = {
-    { MODEL_7B,    32ull*MB },
-    { MODEL_13B,   32ull*MB },
-    { MODEL_30B,   32ull*MB },
-    { MODEL_65B,   32ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH3 = {
-    { MODEL_7B,    32ull*MB },
-    { MODEL_13B,   32ull*MB },
-    { MODEL_30B,   32ull*MB },
-    { MODEL_65B,   32ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH4 = {
-    { MODEL_7B,    128ull*MB },
-    { MODEL_13B,   128ull*MB },
-    { MODEL_30B,   128ull*MB },
-    { MODEL_65B,   128ull*MB },
-};
-
-static const std::map<e_model, size_t> MEM_REQ_SCRATCH5 = {
-    { MODEL_7B,    4ull*MB },
-    { MODEL_13B,   4ull*MB },
-    { MODEL_30B,   4ull*MB },
-    { MODEL_65B,   4ull*MB },
+    { MODEL_7B,    512ull*MB },
+    { MODEL_13B,   512ull*MB },
+    { MODEL_30B,   512ull*MB },
+    { MODEL_65B,   512ull*MB },
 };
 
 // 2*n_embd*n_ctx*n_layer*sizeof(float16)
@@ -98,10 +70,10 @@ static const std::map<e_model, size_t> MEM_REQ_KV_SELF = {
 };
 
 static const std::map<e_model, size_t> MEM_REQ_EVAL = {
-    { MODEL_7B,   256ull*MB },
-    { MODEL_13B,  256ull*MB },
-    { MODEL_30B,  256ull*MB },
-    { MODEL_65B,  256ull*MB },
+    { MODEL_7B,   128ull*MB },
+    { MODEL_13B,  128ull*MB },
+    { MODEL_30B,  128ull*MB },
+    { MODEL_65B,  128ull*MB },
 };
 
 // default hparams (LLaMA 7B)
@@ -516,10 +488,6 @@ static bool llama_model_load(
             ctx_size +
             MEM_REQ_SCRATCH0.at(model.type) +
             MEM_REQ_SCRATCH1.at(model.type) +
-            MEM_REQ_SCRATCH2.at(model.type) +
-            MEM_REQ_SCRATCH3.at(model.type) +
-            MEM_REQ_SCRATCH4.at(model.type) +
-            MEM_REQ_SCRATCH5.at(model.type) +
             MEM_REQ_EVAL.at    (model.type);
 
         // this is the memory required by one llama_state
@@ -864,8 +832,6 @@ static bool llama_eval_internal(
     struct ggml_tensor * embd = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
     memcpy(embd->data, tokens, N*ggml_element_size(embd));
 
-    lctx.use_buf(ctx0, 3);
-
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model.tok_embeddings, embd);
 
     for (int il = 0; il < n_layer; ++il) {
@@ -873,10 +839,10 @@ static bool llama_eval_internal(
 
         struct ggml_tensor * cur;
 
+        lctx.use_buf(ctx0, 0);
+
         // norm
         {
-            lctx.use_buf(ctx0, 0);
-
             cur = ggml_rms_norm(ctx0, inpL);
 
             // cur = attention_norm*cur
@@ -887,9 +853,6 @@ static bool llama_eval_internal(
 
         // self-attention
         {
-            // needed due to ggml_rope creating a "parameters" tensor
-            lctx.use_buf(ctx0, 4);
-
             struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
             struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
             struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
@@ -926,8 +889,6 @@ static bool llama_eval_internal(
             // K * Q
             struct ggml_tensor * KQ = ggml_mul_mat(ctx0, K, Q);
 
-            lctx.use_buf(ctx0, 1);
-
             // KQ_scaled = KQ / sqrt(n_embd/n_head)
             struct ggml_tensor * KQ_scaled =
                 ggml_scale(ctx0,
@@ -935,16 +896,12 @@ static bool llama_eval_internal(
                         ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
                         );
 
-            lctx.use_buf(ctx0, 0);
-
             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
 
             // KQ = soft_max(KQ_masked)
             struct ggml_tensor * KQ_soft_max = ggml_soft_max(ctx0, KQ_masked);
 
-            lctx.use_buf(ctx0, 1);
-
             // V_trans = Vmem.view(n_embd/n_head, n_head, n_past + N).permute(1, 2, 0, 3).contiguous()
             struct ggml_tensor * V_trans =
                 ggml_cpy(ctx0,
@@ -958,8 +915,6 @@ static bool llama_eval_internal(
             // KQV = transpose(V) * KQ_soft_max
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);
 
-            lctx.use_buf(ctx0, 0);
-
             // KQV_merged = KQV.permute(0, 2, 1, 3)
             struct ggml_tensor * KQV_merged = ggml_permute(ctx0, KQV, 0, 2, 1, 3);
 
@@ -968,15 +923,13 @@ static bool llama_eval_internal(
                     KQV_merged,
                     ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, N));
 
-            lctx.use_buf(ctx0, 1);
-
             // projection (no bias)
             cur = ggml_mul_mat(ctx0,
                     model.layers[il].wo,
                     cur);
         }
 
-        lctx.use_buf(ctx0, 2);
+        lctx.use_buf(ctx0, 1);
 
         struct ggml_tensor * inpFF = ggml_add(ctx0, cur, inpSA);
 
@@ -984,20 +937,14 @@ static bool llama_eval_internal(
         {
             // norm
             {
-                lctx.use_buf(ctx0, 0);
-
                 cur = ggml_rms_norm(ctx0, inpFF);
 
-                lctx.use_buf(ctx0, 1);
-
                 // cur = ffn_norm*cur
                 cur = ggml_mul(ctx0,
                         ggml_repeat(ctx0, model.layers[il].ffn_norm, cur),
                         cur);
             }
 
-            lctx.use_buf(ctx0, 0);
-
             struct ggml_tensor * tmp = ggml_mul_mat(ctx0,
                     model.layers[il].w3,
                     cur);
@@ -1009,8 +956,6 @@ static bool llama_eval_internal(
             // SILU activation
             cur = ggml_silu(ctx0, cur);
 
-            lctx.use_buf(ctx0, 1);
-
             cur = ggml_mul(ctx0, cur, tmp);
 
             cur = ggml_mul_mat(ctx0,
@@ -1018,25 +963,22 @@ static bool llama_eval_internal(
                     cur);
         }
 
-        lctx.use_buf(ctx0, 3);
-
         cur = ggml_add(ctx0, cur, inpFF);
 
         // input for next layer
         inpL = cur;
     }
 
+    lctx.use_buf(ctx0, 0);
+
     // used at the end to optionally extract the embeddings
     struct ggml_tensor * embeddings = NULL;
 
     // norm
     {
-        lctx.use_buf(ctx0, 0);
        inpL = ggml_rms_norm(ctx0, inpL);
 
-        lctx.use_buf(ctx0, 1);
-
        // inpL = norm*inpL
        inpL = ggml_mul(ctx0,
                    ggml_repeat(ctx0, model.norm, inpL),
                    inpL);
@@ -1045,8 +987,6 @@ static bool llama_eval_internal(
         embeddings = inpL;
     }
 
-    lctx.use_buf(ctx0, 0);
-
     // lm_head
     inpL = ggml_mul_mat(ctx0, model.output, inpL);
 
@@ -1097,11 +1037,7 @@ static bool llama_eval_internal(
     printf("\n%s: used_mem = %.3f MB, scratch -- %.3f MB, %.3f MB %.3f MB %.3f %.3f %.3f MB\n", __func__,
             ggml_used_mem(ctx0)/1024.0/1024.0,
             lctx.get_buf_max_mem(0)/1024.0/1024.0,
-            lctx.get_buf_max_mem(1)/1024.0/1024.0,
-            lctx.get_buf_max_mem(2)/1024.0/1024.0,
-            lctx.get_buf_max_mem(3)/1024.0/1024.0,
-            lctx.get_buf_max_mem(4)/1024.0/1024.0,
-            lctx.get_buf_max_mem(5)/1024.0/1024.0);
+            lctx.get_buf_max_mem(1)/1024.0/1024.0);
 #endif
 
     ggml_free(ctx0);
@@ -1722,10 +1658,6 @@ struct llama_context * llama_init_from_file(
 
         ctx->buf_scratch[0].resize(MEM_REQ_SCRATCH0.at(ctx->model.type));
         ctx->buf_scratch[1].resize(MEM_REQ_SCRATCH1.at(ctx->model.type));
-        ctx->buf_scratch[2].resize(MEM_REQ_SCRATCH2.at(ctx->model.type));
-        ctx->buf_scratch[3].resize(MEM_REQ_SCRATCH3.at(ctx->model.type));
-        ctx->buf_scratch[4].resize(MEM_REQ_SCRATCH4.at(ctx->model.type));
-        ctx->buf_scratch[5].resize(MEM_REQ_SCRATCH5.at(ctx->model.type));
     }
 
     return ctx;
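
Note on the pattern (not part of the patch): the hunks above replace six per-phase scratch buffers (128/32/4 MB) with two 512 MB buffers that every layer reuses, with buffer 0 selected at the top of each layer and buffer 1 selected after the attention projection, while MEM_REQ_EVAL drops from 256 MB to 128 MB. Below is a minimal, self-contained sketch of that two-arena idea. It is not the ggml/llama.cpp implementation: ScratchArena, EvalContext, use_buf, alloc, and the sizes in main are illustrative stand-ins for llama_context::use_buf, ggml's scratch mechanism, and get_buf_max_mem, and it assumes a buffer is rewound each time it is re-selected.

// sketch.cpp -- illustrative only, C++17
#include <cstddef>
#include <cstdio>
#include <vector>

struct ScratchArena {
    std::vector<unsigned char> data; // pre-allocated region (512 MB per buffer in the patch)
    size_t offs     = 0;             // bump-allocator offset
    size_t max_used = 0;             // high-water mark, analogous to get_buf_max_mem()
};

struct EvalContext {
    ScratchArena scratch[2];
    int cur = -1;                    // -1: no scratch buffer selected

    // Analogue of lctx.use_buf(ctx0, i): later temporaries land in scratch[i].
    // Assumption: re-selecting a buffer rewinds it, so each layer reuses the same space.
    void use_buf(int i) {
        cur = i;
        if (cur >= 0) scratch[cur].offs = 0;
    }

    // Analogue of allocating a tensor while a scratch buffer is active.
    void * alloc(size_t n) {
        if (cur < 0) return nullptr;                    // main eval buffer path omitted
        ScratchArena & b = scratch[cur];
        if (b.offs + n > b.data.size()) return nullptr; // would overflow the arena
        void * p = b.data.data() + b.offs;
        b.offs += n;
        if (b.offs > b.max_used) b.max_used = b.offs;
        return p;
    }

    size_t get_buf_max_mem(int i) const { return scratch[i].max_used; }
};

int main() {
    EvalContext lctx;
    lctx.scratch[0].data.resize(64u*1024*1024); // small sizes just for the demo
    lctx.scratch[1].data.resize(64u*1024*1024);

    for (int il = 0; il < 32; ++il) {
        lctx.use_buf(0);             // attention-phase temporaries
        lctx.alloc(4u*1024*1024);
        lctx.use_buf(1);             // feed-forward phase and layer output
        lctx.alloc(8u*1024*1024);
    }

    std::printf("scratch0 peak = %zu MB, scratch1 peak = %zu MB\n",
                lctx.get_buf_max_mem(0)/1024/1024,
                lctx.get_buf_max_mem(1)/1024/1024);
    return 0;
}

Because each layer cycles through the same two regions instead of six special-purpose ones, sizing reduces to picking two peak values (hence the uniform 512 MB entries in MEM_REQ_SCRATCH0/1), and the debug printf only needs the two high-water marks.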