From 5cf23d11c8b9cfb2da5698e03d35e094b3bee1ea Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 3 Jul 2024 16:10:36 +0200 Subject: [PATCH 1/2] ppl : fix n_seq_max for perplexity --- examples/perplexity/perplexity.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index efde8dfdf..dea694165 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1991,6 +1991,8 @@ int main(int argc, char ** argv) { params.n_batch = std::min(params.n_batch, n_kv); } else { params.n_batch = std::min(params.n_batch, params.n_ctx); + // ensure there's at least enough seq_ids for HellaSwag + params.n_parallel = std::max(4, params.n_parallel); } if (params.ppl_stride > 0) { @@ -2015,9 +2017,6 @@ int main(int argc, char ** argv) { llama_model * model; llama_context * ctx; - // ensure there's at least enough seq_ids for HellaSwag - params.n_parallel = std::max(4, params.n_parallel); - // load the model and apply lora adapter, if any std::tie(model, ctx) = llama_init_from_gpt_params(params); if (model == NULL) { From dcab343f2f06223218635b5975829610d2422541 Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 3 Jul 2024 16:22:58 +0200 Subject: [PATCH 2/2] use 1 seq for kl_divergence --- examples/perplexity/perplexity.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index dea694165..dbe445391 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -1991,8 +1991,12 @@ int main(int argc, char ** argv) { params.n_batch = std::min(params.n_batch, n_kv); } else { params.n_batch = std::min(params.n_batch, params.n_ctx); - // ensure there's at least enough seq_ids for HellaSwag - params.n_parallel = std::max(4, params.n_parallel); + if (params.kl_divergence) { + params.n_parallel = 1; + } else { + // ensure there's at least enough seq_ids for HellaSwag + params.n_parallel = std::max(4, params.n_parallel); + } } if (params.ppl_stride > 0) {