diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 7994603c4..88ae0ee6e 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -442,33 +442,6 @@ static std::vector hellaswag_evaluate_tokens( return result; } -// decode in batches of ctx_params.n_batch tokens -static bool decode_helper(llama_context * ctx, llama_batch & batch, int32_t n_batch) { - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - 0, 0, 0, // unused - }; - - const int ret = llama_decode(ctx, batch_view); - if (ret != 0) { - LOG_TEE("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); - return false; - } - } - - return true; -} - - static void hellaswag_score(llama_context * ctx, const gpt_params & params) { // Calculates hellaswag score (acc_norm) from prompt // @@ -589,7 +562,6 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { const int n_vocab = llama_n_vocab(llama_get_model(ctx)); const int n_ctx = llama_n_ctx(ctx); - const int n_batch = params.n_batch; GGML_ASSERT(params.n_batch >= n_ctx && "HellaSwag currently requires n_batch >= n_ctx"); @@ -648,7 +620,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) { llama_kv_cache_clear(ctx); - if (!decode_helper(ctx, batch, n_batch)) { + if (llama_decode(ctx, batch) != 0) { fprintf(stderr, "%s: llama_decode() failed\n", __func__); return; }