re-use same llama_batch

Jia Liu 2024-08-22 14:33:01 +08:00
parent 27ecb076cf
commit 0451b1f9ef
2 changed files with 7 additions and 3 deletions
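This commit replaces the per-chunk llama_batch_init/llama_batch_free pair inside perplexity_v2 with a single batch that is allocated once, cleared and refilled for every chunk, and freed after the loop. As a rough sketch of that pattern (not the actual diff): the wrapper function eval_chunks, its arguments, and the error handling below are illustrative assumptions; the llama_batch_init / llama_batch_clear / llama_batch_add / llama_batch_free calls are the ones that appear in the diff (the clear/add helpers come from common.h).

#include "llama.h"
#include "common.h"   // provides the llama_batch_clear / llama_batch_add helpers

#include <algorithm>
#include <vector>

// Hypothetical helper showing the reuse pattern: allocate one llama_batch up
// front, clear and refill it for every chunk, free it once after the loop.
static bool eval_chunks(llama_context * ctx, const std::vector<llama_token> & tokens, int n_batch) {
    const int n_tokens    = (int) tokens.size();
    const int num_batches = (n_tokens + n_batch - 1) / n_batch;

    llama_batch batch = llama_batch_init(n_batch, 0, 1);   // sized for the largest chunk

    bool ok = true;
    for (int j = 0; j < num_batches && ok; ++j) {
        const int batch_start = j * n_batch;
        const int batch_size  = std::min(n_tokens - batch_start, n_batch);

        llama_batch_clear(batch);                           // reset counts instead of re-allocating
        for (int k = 0; k < batch_size; ++k) {
            llama_batch_add(batch, tokens[batch_start + k], batch_start + k, { 0 }, true);
        }

        if (llama_decode(ctx, batch) != 0) {
            ok = false;                                     // fall through so the batch is still freed
        }
    }

    llama_batch_free(batch);                                // single free, after all chunks
    return ok;
}

Allocating once and clearing per chunk avoids paying the allocation/free cost on every chunk, which is the point of the change below.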


@@ -385,6 +385,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
double nll = 0.0;
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
llama_batch batch = llama_batch_init(n_batch, 0, 1);
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
@@ -403,7 +404,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const int batch_start = start + j * n_batch;
const int batch_size = std::min(end - batch_start, n_batch);
llama_batch batch = llama_batch_init(batch_size, 0, 1);
llama_batch_clear(batch);
for (int k = 0; k < batch_size; ++k) {
const int idx = batch_start + k;
llama_batch_add(batch, tokens[idx], j*n_batch + k, {0}, true);
@@ -415,7 +417,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
return {tokens, -1, logit_history, prob_history};
}
llama_batch_free(batch);
// save original token and restore it after eval
const auto token_org = tokens[batch_start];
@@ -468,6 +469,9 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
}
fflush(stdout);
}
llama_batch_free(batch);
printf("\n");
return {tokens, std::exp(nll / count), logit_history, prob_history};


@@ -230,7 +230,7 @@ extern "C" {
llama_pos * pos;
int32_t * n_seq_id;
llama_seq_id ** seq_id;
int8_t * output; // Previously named 'logits', renamed to 'output' now.
// NOTE: helpers for smooth API transition - can be deprecated in the future
// for future-proof code, use the above fields instead and ignore everything below