From 6ab116ac5a90ebc4fd51e0683703db5cfaf88dfa Mon Sep 17 00:00:00 2001 From: Xuan Son Nguyen Date: Tue, 22 Oct 2024 13:01:22 +0200 Subject: [PATCH] move batch_allocr inside decode/encode_internal --- src/llama.cpp | 33 ++++++++++++++++----------------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/llama.cpp b/src/llama.cpp index d64200402..4d424bfaf 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -17108,16 +17108,19 @@ static void llama_graph_compute( // static int llama_decode_internal( llama_context & lctx, - llama_batch batch) { + llama_batch inp_batch) { lctx.is_encoding = false; - const uint32_t n_tokens_all = batch.n_tokens; - if (n_tokens_all == 0) { + if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } + llama_batch_allocr batch_allocr(lctx, inp_batch); + llama_batch batch = batch_allocr.batch; + const uint32_t n_tokens_all = batch.n_tokens; + const auto & model = lctx.model; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; @@ -17422,17 +17425,19 @@ static int llama_decode_internal( // static int llama_encode_internal( llama_context & lctx, - llama_batch batch) { + llama_batch inp_batch) { lctx.is_encoding = true; - const uint32_t n_tokens = batch.n_tokens; - - if (n_tokens == 0) { + if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } + llama_batch_allocr batch_allocr(lctx, inp_batch); + llama_batch batch = batch_allocr.batch; + const uint32_t n_tokens = batch.n_tokens; + const auto & model = lctx.model; const auto & hparams = model.hparams; const auto & cparams = lctx.cparams; @@ -21137,16 +21142,12 @@ struct llama_batch_allocr { std::vector logits; struct llama_batch batch; // optionally fulfill the batch returned by llama_batch_get_one - llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) { + llama_batch_allocr(llama_context & ctx, struct llama_batch in_batch) { batch = in_batch; - if (batch.n_tokens == 0) { - // llama_(de|en)code_internal will return an error in this case - return; - } if (!batch.pos) { // determine the last position in KV cache llama_pos last_pos = -1; - for (const auto & cell : ctx->kv_self.cells) { + for (const auto & cell : ctx.kv_self.cells) { if (cell.has_seq_id(batch_default_seq_id)) { last_pos = std::max(last_pos, cell.pos); } @@ -21184,8 +21185,7 @@ struct llama_batch_allocr { int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch) { - llama_batch_allocr batch_allocr(ctx, batch); - const int ret = llama_encode_internal(*ctx, batch_allocr.batch); + const int ret = llama_encode_internal(*ctx, batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); } @@ -21196,8 +21196,7 @@ int32_t llama_encode( int32_t llama_decode( struct llama_context * ctx, struct llama_batch batch) { - llama_batch_allocr batch_allocr(ctx, batch); - const int ret = llama_decode_internal(*ctx, batch_allocr.batch); + const int ret = llama_decode_internal(*ctx, batch); if (ret != 0) { LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); }