From 7392ad629d60ee81a44dfdac6289a1b87d7f404f Mon Sep 17 00:00:00 2001 From: Gary Linscott Date: Sat, 25 Mar 2023 13:30:40 -0700 Subject: [PATCH] update from merge --- examples/perplexity/perplexity.cpp | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index f617ba365..91f0bf6b9 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -26,17 +26,26 @@ void perplexity(llama_context * ctx, const gpt_params & params) { int count = 0; double nll = 0.0; int seq_count = tokens.size() / params.n_ctx; + int n_vocab = llama_n_vocab(ctx); - fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count); + fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch); for (int i = 0; i < seq_count; ++i) { int start = i * params.n_ctx; int end = start + params.n_ctx - 1; - std::vector embd(tokens.begin() + start, tokens.begin() + end); + + std::vector logits; + int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch; auto start_t = std::chrono::high_resolution_clock::now(); - if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return; + for (int j = 0; j < num_batches; ++j) { + int batch_start = start + j * params.n_batch; + int batch_size = std::min(end - batch_start, params.n_batch); + if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) { + fprintf(stderr, "%s : failed to eval\n", __func__); + return; + } + auto batch_logits = llama_get_logits(ctx); + logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); } auto end_t = std::chrono::high_resolution_clock::now(); if (i == 0) { @@ -56,13 +65,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) { // last 256 tokens. Then, we split the input up into context window size chunks to // process the entire prompt. - auto logits = llama_get_logits(ctx); for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) { // Calculate probability of next token, given the previous ones. - int n_vocab = llama_n_vocab(ctx); std::vector tok_logits( - logits + j * n_vocab, - logits + (j + 1) * n_vocab); + logits.begin() + j * n_vocab, + logits.begin() + (j + 1) * n_vocab); double prob = softmax(tok_logits)[tokens[start + j + 1]]; nll += -std::log(prob); ++count;