From 9df62c25f7db072ecb88252b61076403c2051e32 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 18 Jan 2024 14:58:13 +0200
Subject: [PATCH] perplexity : remove HellaSwag restriction for n_batch

---
 examples/perplexity/perplexity.cpp | 37 ++++++++++++++++++++++++++----
 1 file changed, 32 insertions(+), 5 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 343eeaa45..ea2c8026c 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -564,8 +564,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
 
     const int n_vocab = llama_n_vocab(llama_get_model(ctx));
     const int n_ctx   = llama_n_ctx(ctx);
-
-    GGML_ASSERT(params.n_batch >= n_ctx && "HellaSwag currently requires n_batch >= n_ctx");
+    const int n_batch = params.n_batch;
 
     const int max_tasks_per_batch = params.n_parallel;
     const int max_seq = 4*max_tasks_per_batch;
@@ -573,6 +572,34 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
     llama_batch batch = llama_batch_init(n_ctx, 0, max_seq);
 
     std::vector<float> tok_logits(n_vocab);
+    std::vector<float> batch_logits(n_ctx*n_vocab);
+
+    auto decode_helper = [&](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
+        for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+            const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+
+            llama_batch batch_view = {
+                n_tokens,
+                batch.token    + i,
+                nullptr,
+                batch.pos      + i,
+                batch.n_seq_id + i,
+                batch.seq_id   + i,
+                batch.logits   + i,
+                0, 0, 0, // unused
+            };
+
+            const int ret = llama_decode(ctx, batch_view);
+            if (ret != 0) {
+                LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
+                return false;
+            }
+
+            memcpy(batch_logits.data() + i*n_vocab, llama_get_logits(ctx), n_tokens*n_vocab*sizeof(float));
+        }
+
+        return true;
+    };
 
     for (size_t i0 = 0; i0 < hs_task_count; i0++) {
         int n_cur = 0;
@@ -622,7 +649,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         llama_kv_cache_clear(ctx);
 
         // decode all tasks [i0, i1)
-        if (llama_decode(ctx, batch) != 0) {
+        if (!decode_helper(ctx, batch, n_batch)) {
             fprintf(stderr, "%s: llama_decode() failed\n", __func__);
             return;
         }
@@ -631,7 +658,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
         for (size_t i = i0; i < i1; ++i) {
            auto & hs_cur = hs_data[i];
 
-            std::memcpy(tok_logits.data(), llama_get_logits_ith(ctx, hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
+            std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + hs_cur.common_prefix - 1), n_vocab*sizeof(float));
 
             const auto first_probs = softmax(tok_logits);
 
@@ -643,7 +670,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
             // Calculate the logprobs over the ending
             for (size_t j = hs_cur.common_prefix; j < hs_cur.seq_tokens[s].size() - 1; j++) {
-                std::memcpy(tok_logits.data(), llama_get_logits_ith(ctx, hs_cur.i_batch + li++), n_vocab*sizeof(float));
+                std::memcpy(tok_logits.data(), batch_logits.data() + n_vocab*(hs_cur.i_batch + li++), n_vocab*sizeof(float));
 
                 const float prob = softmax(tok_logits)[hs_cur.seq_tokens[s][j + 1]];