From 7740c969d0470b91d57168651e976deb7cda4d9d Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Fri, 11 Oct 2024 13:59:26 +0200
Subject: [PATCH] fix

---
 src/llama.cpp | 35 ++++++++++++++++++-----------------
 1 file changed, 18 insertions(+), 17 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index c3646eb4d..c25ae1e1e 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -21081,58 +21081,59 @@ void llama_batch_free(struct llama_batch batch) {
 }
 
 // temporary allocate memory for the input batch if needed
+static const llama_seq_id batch_default_seq_id = 0;
 struct llama_batch_allocr {
-    static const llama_seq_id default_seq_id = 0;
-    std::array<llama_seq_id, 1> seq_id_0 = {default_seq_id};
+    std::array<llama_seq_id, 1> seq_id_0 = {batch_default_seq_id};
     std::vector<llama_pos>      pos;
     std::vector<int32_t>        n_seq_id;
     std::vector<llama_seq_id *> seq_id;
     std::vector<int8_t>         logits;
-    // fulfill the batch returned by llama_batch_get_one
-    struct llama_batch get_fulfilled_batch(struct llama_context * ctx, struct llama_batch in_batch) {
-        struct llama_batch batch = in_batch;
+    struct llama_batch batch;
+    // optionally fulfill the batch returned by llama_batch_get_one
+    llama_batch_allocr(struct llama_context * ctx, struct llama_batch in_batch) {
+        batch = in_batch;
         if (!batch.pos) {
             // determine the last position in KV cache
             llama_pos last_pos = 0;
             for (const auto & cell : ctx->kv_self.cells) {
-                if (cell.seq_id.find(default_seq_id) != cell.seq_id.end()) {
+                if (cell.has_seq_id(batch_default_seq_id)) {
                     last_pos = std::max(last_pos, cell.pos);
                 }
             }
+            last_pos++; // next position
             pos.resize(batch.n_tokens);
-            for (int32_t i = 1; i <= batch.n_tokens; i++) {
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
                 pos[i] = i+last_pos;
             }
             batch.pos = pos.data();
         }
         if (!batch.n_seq_id) {
-            n_seq_id.reserve(batch.n_tokens);
-            for (int32_t i = 1; i <= batch.n_tokens; i++) {
+            n_seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
                 n_seq_id[i] = seq_id_0.size();
             }
             batch.n_seq_id = n_seq_id.data();
         }
         if (!batch.seq_id) {
-            seq_id.reserve(batch.n_tokens);
-            for (int32_t i = 1; i <= batch.n_tokens; i++) {
+            seq_id.resize(batch.n_tokens);
+            for (int32_t i = 0; i < batch.n_tokens; i++) {
                 seq_id[i] = seq_id_0.data();
             }
             batch.seq_id = seq_id.data();
         }
         if (!batch.logits) {
-            logits.reserve(batch.n_tokens);
+            logits.resize(batch.n_tokens);
             logits[logits.size() - 1] = true;
             batch.logits = logits.data();
         }
-        return batch;
     }
 };
 
 int32_t llama_encode(
         struct llama_context * ctx,
         struct llama_batch batch) {
-    llama_batch_allocr batch_allocr;
-    const int ret = llama_encode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
+    llama_batch_allocr batch_allocr(ctx, batch);
+    const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
@@ -21143,8 +21144,8 @@ int32_t llama_decode(
         struct llama_context * ctx,
         struct llama_batch batch) {
-    llama_batch_allocr batch_allocr;
-    const int ret = llama_decode_internal(*ctx, batch_allocr.get_fulfilled_batch(ctx, batch));
+    llama_batch_allocr batch_allocr(ctx, batch);
+    const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
     if (ret < 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
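
Usage note (not part of the patch): after this change, the fields that llama_batch_get_one() leaves NULL (pos, n_seq_id, seq_id, logits) are filled in by the llama_batch_allocr constructor inside llama_encode()/llama_decode(). The sketch below illustrates a caller under that behavior; it assumes the two-argument llama_batch_get_one(tokens, n_tokens) signature and an already-initialized llama_context, and the helper name decode_tokens is hypothetical.

    // Sketch: decode tokens with a batch from llama_batch_get_one().
    // pos/n_seq_id/seq_id/logits stay NULL here; llama_batch_allocr assigns
    // positions continuing from the last cell of sequence 0 in the KV cache
    // and requests logits only for the final token.
    #include "llama.h"
    #include <cstdio>
    #include <vector>

    static int decode_tokens(struct llama_context * ctx, std::vector<llama_token> & tokens) {
        struct llama_batch batch = llama_batch_get_one(tokens.data(), (int32_t) tokens.size());
        const int ret = llama_decode(ctx, batch);
        if (ret < 0) {
            fprintf(stderr, "llama_decode failed: ret = %d\n", ret);
        }
        return ret;
    }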