From 0451b1f9ef8a2f95f0360fe881054d249a3f5a3b Mon Sep 17 00:00:00 2001 From: Jia Liu Date: Thu, 22 Aug 2024 14:33:01 +0800 Subject: [PATCH] re-use same llama_batch --- examples/perplexity/perplexity.cpp | 8 ++++++-- include/llama.h | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 9fec60878..ae21bfbaf 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -385,6 +385,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & double nll = 0.0; const int num_batches = (n_ctx + n_batch - 1) / n_batch; + llama_batch batch = llama_batch_init(n_batch, 0, 1); fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); @@ -403,7 +404,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - llama_batch batch = llama_batch_init(batch_size, 0, 1); + llama_batch_clear(batch); + for (int k = 0; k < batch_size; ++k) { const int idx = batch_start + k; llama_batch_add(batch, tokens[idx], j*n_batch + k, {0}, true); @@ -415,7 +417,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & return {tokens, -1, logit_history, prob_history}; } - llama_batch_free(batch); // save original token and restore it after eval const auto token_org = tokens[batch_start]; @@ -468,6 +469,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & } fflush(stdout); } + + llama_batch_free(batch); + printf("\n"); return {tokens, std::exp(nll / count), logit_history, prob_history}; diff --git a/include/llama.h b/include/llama.h index d1b2238c3..e942c9d3f 100644 --- a/include/llama.h +++ b/include/llama.h @@ -230,7 +230,7 @@ extern "C" { llama_pos * pos; int32_t * n_seq_id; llama_seq_id ** seq_id; - int8_t * output; // Previously named 'logits', renamed to 'output' now. + int8_t * output; // Previously named 'logits', renamed to 'output' now. // NOTE: helpers for smooth API transition - can be deprecated in the future // for future-proof code, use the above fields instead and ignore everything below