print tested n_ctx, add assert
This commit is contained in:
parent
ac07f7d0f7
commit
23dbcfa2c8
1 changed files with 4 additions and 2 deletions
|
@ -498,9 +498,11 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
double nll2 = 0.0;
|
||||
|
||||
const int num_batches = (n_ctx + n_batch - 1) / n_batch;
|
||||
const int n_seq = std::max(1, n_batch / n_ctx);
|
||||
|
||||
GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0);
|
||||
const int n_seq = std::max(1, n_batch / n_ctx);
|
||||
GGML_ASSERT(params.n_ctx == n_seq * n_ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1);
|
||||
|
||||
std::vector<float> logits;
|
||||
|
@ -508,7 +510,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
logits.reserve((size_t)n_ctx * n_vocab);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_batch, n_seq);
|
||||
fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq);
|
||||
|
||||
std::vector<std::thread> workers(std::thread::hardware_concurrency() - 1);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue