From ac07f7d0f7453a6363377db0865f5edca4268aa3 Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 9 Mar 2024 03:42:37 +0100 Subject: [PATCH] set cparams.n_parallel to the number of sequences --- examples/perplexity/perplexity.cpp | 4 +++- llama.cpp | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index e8b19b046..10028e6c2 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1822,7 +1822,9 @@ int main(int argc, char ** argv) { const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence; if (ppl) { - int32_t n_kv = std::max(1, params.n_batch / n_ctx) * n_ctx; + int n_seq = std::max(1, params.n_batch / n_ctx); + int32_t n_kv = n_seq * n_ctx; + params.n_parallel = n_seq; params.n_ctx = n_kv; params.n_batch = std::min(params.n_batch, n_kv); } else { diff --git a/llama.cpp b/llama.cpp index 775de9985..2a0b1c58a 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8943,7 +8943,7 @@ static int llama_decode_internal( } } #ifndef NDEBUG - logits_valid[i] = batch.logits[i] == 1; + logits_valid[i] = batch.logits[i] != 0; #endif } } else if (lctx.logits_all) {