perplexity: print the tested n_ctx and add an assert that params.n_ctx == n_seq * n_ctx

This commit is contained in:
slaren 2024-03-09 19:43:08 +01:00
parent ac07f7d0f7
commit 23dbcfa2c8

View file

@@ -498,9 +498,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
double nll2 = 0.0; double nll2 = 0.0;
const int num_batches = (n_ctx + n_batch - 1) / n_batch; const int num_batches = (n_ctx + n_batch - 1) / n_batch;
const int n_seq = std::max(1, n_batch / n_ctx);
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
const int n_seq = std::max(1, n_batch / n_ctx); GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
std::vector<float> logits; std::vector<float> logits;
@@ -508,7 +510,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
logits.reserve((size_t)n_ctx * n_vocab); logits.reserve((size_t)n_ctx * n_vocab);
} }
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_batch, n_seq); fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1); std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);