diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 10028e6c2..293eb52c3 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -498,9 +498,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
     double nll2 = 0.0;
 
     const int num_batches = (n_ctx + n_batch - 1) / n_batch;
+    const int n_seq = std::max(1, n_batch / n_ctx);
 
     GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
-    const int n_seq = std::max(1, n_batch / n_ctx);
+    GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
+    llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
 
     std::vector<float> logits;
 
@@ -508,7 +510,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
         logits.reserve((size_t)n_ctx * n_vocab);
     }
 
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_batch, n_seq);
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
 
     std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
 