From b791d1f489f22011b970353e27fefc0d711f5694 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Tue, 22 Aug 2023 15:55:52 +0300 Subject: [PATCH] Implementing strided computation of perplexity --- common/common.cpp | 6 ++ common/common.h | 2 + examples/perplexity/perplexity.cpp | 116 +++++++++++++++++++++++++++++ 3 files changed, 124 insertions(+) diff --git a/common/common.cpp b/common/common.cpp index d7e1a5725..6b937f729 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -418,6 +418,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.antiprompt.push_back(argv[i]); } else if (arg == "--perplexity") { params.perplexity = true; + } else if (arg == "--ppl-stride") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.ppl_stride = std::stoi(argv[i]); } else if (arg == "--hellaswag") { params.hellaswag = true; } else if (arg == "--hellaswag-tasks") { diff --git a/common/common.h b/common/common.h index c50a6edfc..184092f06 100644 --- a/common/common.h +++ b/common/common.h @@ -64,6 +64,8 @@ struct gpt_params { std::string lora_adapter = ""; // lora adapter path std::string lora_base = ""; // base model path for the lora adapter + int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. + // bool hellaswag = false; // compute HellaSwag score over random tasks from datafile supplied in prompt size_t hellaswag_tasks = 400; // number of tasks to use when computing the HellaSwag score diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index f3c045aec..22425e14f 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -27,7 +27,117 @@ std::vector softmax(const std::vector& logits) { return probs; } +void perplexity_v2(llama_context * ctx, const gpt_params & params) { + + // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research + // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` + // Output: `perplexity: 13.5106 [114/114]` + // BOS tokens will be added for each chunk before eval + + if (params.ppl_stride <= 0) { + fprintf(stderr, "%s: stride is %d but must be greater than zero!\n",__func__,params.ppl_stride); + return; + } + auto tokens = ::llama_tokenize(ctx, params.prompt, true); + + const int calc_chunk = params.n_ctx; + + fprintf(stderr, "%s: have %zu tokens. Calculation chunk = %d\n", __func__, tokens.size(), calc_chunk); + + if (int(tokens.size()) <= calc_chunk) { + fprintf(stderr, "%s: there are only %zu tokens, this is not enough for a context size of %d and stride %d\n",__func__, + tokens.size(), params.n_ctx, params.ppl_stride); + return; + } + + const int n_chunk_max = (tokens.size() - calc_chunk + params.ppl_stride - 1) / params.ppl_stride; + + const int n_chunk = params.n_chunks < 0 ? n_chunk_max : std::min(params.n_chunks, n_chunk_max); + const int n_vocab = llama_n_vocab(ctx); + const int n_batch = params.n_batch; + + int count = 0; + double nll = 0.0; + + fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); + + for (int i = 0; i < n_chunk; ++i) { + const int start = i * params.ppl_stride; + const int end = start + calc_chunk; + + const int num_batches = (calc_chunk + n_batch - 1) / n_batch; + //fprintf(stderr, "%s: evaluating %d...%d using %d batches\n", __func__, start, end, num_batches); + + std::vector logits; + + const auto t_start = std::chrono::high_resolution_clock::now(); + + for (int j = 0; j < num_batches; ++j) { + const int batch_start = start + j * n_batch; + const int batch_size = std::min(end - batch_start, n_batch); + + //fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); + if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) { + //fprintf(stderr, "%s : failed to eval\n", __func__); + return; + } + + // save original token and restore it after eval + const auto token_org = tokens[batch_start]; + + // add BOS token for the first batch of each chunk + if (j == 0) { + tokens[batch_start] = llama_token_bos(ctx); + } + + const auto batch_logits = llama_get_logits(ctx); + logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); + + if (j == 0) { + tokens[batch_start] = token_org; + } + } + + const auto t_end = std::chrono::high_resolution_clock::now(); + + if (i == 0) { + const float t_total = std::chrono::duration(t_end - t_start).count(); + fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); + int total_seconds = (int)(t_total * n_chunk); + if (total_seconds >= 60*60) { + fprintf(stderr, "%d hours ", total_seconds / (60*60)); + total_seconds = total_seconds % (60*60); + } + fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); + } + + //fprintf(stderr, "%s: using tokens %d...%d\n",__func__,params.n_ctx - params.ppl_stride + start, params.n_ctx + start); + for (int j = params.n_ctx - params.ppl_stride - 1; j < params.n_ctx - 1; ++j) { + + // Calculate probability of next token, given the previous ones. + const std::vector tok_logits( + logits.begin() + (j + 0) * n_vocab, + logits.begin() + (j + 1) * n_vocab); + + const float prob = softmax(tok_logits)[tokens[start + j + 1]]; + + nll += -std::log(prob); + ++count; + } + // perplexity is e^(average negative log-likelihood) + printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); + fflush(stdout); + } + printf("\n"); +} + void perplexity(llama_context * ctx, const gpt_params & params) { + + if (params.ppl_stride > 0) { + perplexity_v2(ctx, params); + return; + } + // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw` // Output: `perplexity: 13.5106 [114/114]` @@ -369,6 +479,12 @@ int main(int argc, char ** argv) { params.perplexity = true; params.n_batch = std::min(params.n_batch, params.n_ctx); + if (params.ppl_stride > 0) { + fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n", + params.n_ctx, params.n_ctx + params.ppl_stride/2); + params.n_ctx += params.ppl_stride/2; + } + if (params.n_ctx > 2048) { fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);" "expect poor results\n", __func__, params.n_ctx);