From 734f9e29de8421afac198b86ad454937d94e672c Mon Sep 17 00:00:00 2001
From: Xuan Son Nguyen
Date: Sat, 12 Oct 2024 22:51:30 +0200
Subject: [PATCH] use common_batch_add, reuse llama_batch in loop

---
 examples/imatrix/imatrix.cpp       | 13 ++++++-------
 examples/perplexity/perplexity.cpp | 13 ++++++-------
 2 files changed, 12 insertions(+), 14 deletions(-)

diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp
index 1e97d2980..70ff47768 100644
--- a/examples/imatrix/imatrix.cpp
+++ b/examples/imatrix/imatrix.cpp
@@ -496,6 +496,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_cache_clear(ctx);
 
+        llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
@@ -508,12 +510,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }
 
-            llama_batch batch = llama_batch_init(batch_size, 0, 1);
+            common_batch_clear(batch);
             for (int i = 0; i < batch_size; i++) {
-                batch. token[i] = tokens[batch_start + i];
-                batch.   pos[i] = j*n_batch + i;
-                batch.logits[i] = true;
-                batch.seq_id[i][0] = 0;
+                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
             }
 
             if (llama_decode(ctx, batch)) {
@@ -522,8 +521,6 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
                 return false;
             }
 
-            llama_batch_free(batch);
-
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;
 
@@ -533,6 +530,8 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) {
             }
         }
 
+        llama_batch_free(batch);
+
         const auto t_end = std::chrono::high_resolution_clock::now();
 
         if (i == 0) {
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 181a3c86d..252ef56ba 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -1800,6 +1800,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
         // clear the KV cache
         llama_kv_cache_clear(ctx);
 
+        llama_batch batch = llama_batch_init(n_batch, 0, 1);
+
         for (int j = 0; j < num_batches; ++j) {
             const int batch_start = start + j * n_batch;
             const int batch_size  = std::min(end - batch_start, n_batch);
@@ -1812,12 +1814,9 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
                 tokens[batch_start] = llama_token_bos(llama_get_model(ctx));
             }
 
-            llama_batch batch = llama_batch_init(batch_size, 0, 1);
+            common_batch_clear(batch);
             for (int i = 0; i < batch_size; i++) {
-                batch. token[i] = tokens[batch_start + i];
-                batch.   pos[i] = j*n_batch + i;
-                batch.logits[i] = true;
-                batch.seq_id[i][0] = 0;
+                common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true);
             }
 
             if (llama_decode(ctx, batch)) {
@@ -1826,8 +1825,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
                 return;
             }
 
-            llama_batch_free(batch);
-
             // restore the original token in case it was set to BOS
             tokens[batch_start] = token_org;
 
@@ -1837,6 +1834,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) {
             }
         }
 
+        llama_batch_free(batch);
+
         const auto t_end = std::chrono::high_resolution_clock::now();
 
         if (i == 0) {
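
Note (not part of the patch): the sketch below illustrates the pattern both hunks apply. Instead of calling llama_batch_init()/llama_batch_free() on every loop iteration, the batch is allocated once with capacity n_batch before the loop, refilled each iteration with common_batch_clear() and common_batch_add(), and freed once after the loop. The helper eval_chunk() and its arguments are hypothetical stand-ins for illustration, not code from this patch.

#include "common.h"
#include "llama.h"

#include <algorithm>
#include <vector>

// Hypothetical helper demonstrating batch reuse across decode calls.
static bool eval_chunk(llama_context * ctx, const std::vector<llama_token> & tokens, int n_batch) {
    const int n_tokens    = (int) tokens.size();
    const int num_batches = (n_tokens + n_batch - 1) / n_batch;

    // allocate once, sized for the largest possible batch
    llama_batch batch = llama_batch_init(n_batch, 0, 1);

    for (int j = 0; j < num_batches; ++j) {
        const int batch_start = j * n_batch;
        const int batch_size  = std::min(n_tokens - batch_start, n_batch);

        // reset the token count and refill the same buffers
        common_batch_clear(batch);
        for (int i = 0; i < batch_size; i++) {
            // args: batch, token, position, seq_ids, output logits for this token?
            common_batch_add(batch, tokens[batch_start + i], batch_start + i, {0}, true);
        }

        if (llama_decode(ctx, batch)) {
            llama_batch_free(batch);
            return false;
        }
    }

    llama_batch_free(batch);
    return true;
}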