From 4351c4632daeed3c412f901b0096919c7a53454b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 18 Jan 2024 13:51:02 +0200 Subject: [PATCH] perplexity : clean-up ggml-ci --- examples/perplexity/perplexity.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 8ed58ec28..7994603c4 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -591,23 +591,27 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const int n_ctx = llama_n_ctx(ctx); const int n_batch = params.n_batch; + GGML_ASSERT(params.n_batch >= n_ctx && "HellaSwag currently requires n_batch >= n_ctx"); + const int max_tasks_per_batch = 32; const int max_seq = 4*max_tasks_per_batch; - llama_batch batch = llama_batch_init(n_batch, 0, max_seq); - - std::vector> ending_tokens(4); + llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); std::vector tok_logits(n_vocab); for (size_t i0 = 0; i0 < hs_task_count; i0++) { int n_cur = 0; + size_t i1 = i0; - size_t i_batch = 0; + size_t i_batch = 0; // this tells us where in `llama_batch` we are currently llama_batch_clear(batch); // batch as much tasks as possible into the available context + // each task has 4 unique sequence ids + // the common prefix is shared among the 4 sequences to save tokens + // we extract logits only from the last common token and all ending tokens of each sequence while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) { auto & hs_cur = hs_data[i1]; @@ -649,6 +653,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { return; } + // evaluate the computed tasks for (size_t i = i0; i < i1; ++i) { auto & hs_cur = hs_data[i];