From 30ebd9472354ea83b7ebaa0f9830c9079e8f1995 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 14:20:11 +0200 Subject: [PATCH] perplexity : add comments --- examples/perplexity/perplexity.cpp | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 0b0ef9a10..218ab6746 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -583,9 +583,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { llama_batch_clear(batch); // batch as much tasks as possible into the available context - // each task has 4 unique seuqnce ids + // each task has 4 unique sequence ids - one for each ending // the common prefix is shared among the 4 sequences to save tokens - // we extract logits only from the last common token and all ending tokens of each sequence + // we extract logits only from the last common token and from all ending tokens of each sequence while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) { auto & hs_cur = hs_data[i1]; @@ -598,7 +598,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { for (size_t i = 0; i < hs_cur.common_prefix; ++i) { llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); } - batch.logits[batch.n_tokens - 1] = true; + batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix for (int s = 0; s < 4; ++s) { for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) { @@ -622,12 +622,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { llama_kv_cache_clear(ctx); + // decode all tasks [i0, i1) if (llama_decode(ctx, batch) != 0) { fprintf(stderr, "%s: llama_decode() failed\n", __func__); return; } - // evaluate the computed tasks + // compute the logprobs for each ending of the decoded tasks for (size_t i = i0; i < i1; 
++i) { auto & hs_cur = hs_data[i];