re-use same llama_batch

This commit is contained in:
Jia Liu 2024-08-22 14:33:01 +08:00
parent 27ecb076cf
commit 0451b1f9ef
2 changed files with 7 additions and 3 deletions

View file

@@ -385,6 +385,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
double nll = 0.0; double nll = 0.0;
const int num_batches = (n_ctx + n_batch - 1) / n_batch; const int num_batches = (n_ctx + n_batch - 1) / n_batch;
llama_batch batch = llama_batch_init(n_batch, 0, 1);
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
@@ -403,7 +404,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const int batch_start = start + j * n_batch; const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch); const int batch_size = std::min(end - batch_start, n_batch);
llama_batch batch = llama_batch_init(batch_size, 0, 1); llama_batch_clear(batch);
for (int k = 0; k < batch_size; ++k) { for (int k = 0; k < batch_size; ++k) {
const int idx = batch_start + k; const int idx = batch_start + k;
llama_batch_add(batch, tokens[idx], j*n_batch + k, {0}, true); llama_batch_add(batch, tokens[idx], j*n_batch + k, {0}, true);
@@ -415,7 +417,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {tokens, -1, logit_history, prob_history}; return {tokens, -1, logit_history, prob_history};
} }
llama_batch_free(batch);
// save original token and restore it after eval // save original token and restore it after eval
const auto token_org = tokens[batch_start]; const auto token_org = tokens[batch_start];
@@ -468,6 +469,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
} }
fflush(stdout); fflush(stdout);
} }
llama_batch_free(batch);
printf("\n"); printf("\n");
return {tokens, std::exp(nll / count), logit_history, prob_history}; return {tokens, std::exp(nll / count), logit_history, prob_history};

View file

@@ -230,7 +230,7 @@ extern "C" {
llama_pos * pos; llama_pos * pos;
int32_t * n_seq_id; int32_t * n_seq_id;
llama_seq_id ** seq_id; llama_seq_id ** seq_id;
int8_t * output; // Previously named 'logits', renamed to 'output' now. int8_t * output; // Previously named 'logits', renamed to 'output' now.
// NOTE: helpers for smooth API transition - can be deprecated in the future // NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below // for future-proof code, use the above fields instead and ignore everything below