diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index 90dbfb6e0..5cfab3b18 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -25,46 +25,56 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     // Download: https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-raw-v1.zip?ref=salesforce-research
     // Run `./perplexity -m models/7B/ggml-model-q4_0.bin -f wiki.test.raw`
     // Output: `perplexity: 13.5106 [114/114]`
+    // BOS tokens will be added for each chunk before eval
     auto tokens = ::llama_tokenize(ctx, params.prompt, true);
 
-    int count = 0;
-    int seq_count = tokens.size() / params.n_ctx;
-    int n_vocab = llama_n_vocab(ctx);
+    int count = 0;
+
+    const int n_chunk = tokens.size() / params.n_ctx;
+    const int n_vocab = llama_n_vocab(ctx);
+    const int n_batch = params.n_batch;
 
     double nll = 0.0;
-    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);
+    fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch);
 
-    for (int i = 0; i < seq_count; ++i) {
-        const int start = i * params.n_ctx;
+    for (int i = 0; i < n_chunk; ++i) {
+        const int start =     i * params.n_ctx;
         const int end   = start + params.n_ctx;
 
-        std::vector<float> logits;
-        const int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
+        const int num_batches = (params.n_ctx + n_batch - 1) / n_batch;
 
-        const auto start_t = std::chrono::high_resolution_clock::now();
+        std::vector<float> logits;
+
+        const auto t_start = std::chrono::high_resolution_clock::now();
 
         for (int j = 0; j < num_batches; ++j) {
-            const int batch_start = start + j * params.n_batch;
-            const int batch_size  = std::min(end - batch_start, params.n_batch);
+            const int batch_start = start + j * n_batch;
+            const int batch_size  = std::min(end - batch_start, n_batch);
 
-            // TODO: not perfect since this can be in the middle of a word, but it is better than nothing
-            tokens[batch_start] = llama_token_bos();
+            // save original token and restore it after eval
+            const auto token_org = tokens[batch_start];
 
-            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
+            if (j == 0) {
+                tokens[batch_start] = llama_token_bos();
+            }
+
+            if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return;
             }
 
+            tokens[batch_start] = token_org;
+
             const auto batch_logits = llama_get_logits(ctx);
             logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
         }
 
-        const auto end_t = std::chrono::high_resolution_clock::now();
+        const auto t_end = std::chrono::high_resolution_clock::now();
 
         if (i == 0) {
-            const float seconds = std::chrono::duration<float>(end_t - start_t).count();
-            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, seconds);
-            int total_seconds = (int)(seconds * seq_count);
+            const float t_total = std::chrono::duration<float>(t_end - t_start).count();
+            fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total);
+            int total_seconds = (int)(t_total * n_chunk);
             if (total_seconds >= 60*60) {
                 fprintf(stderr, "%d hours ", total_seconds / (60*60));
                 total_seconds = total_seconds % (60*60);
@@ -74,7 +84,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // We get the logits for all the tokens in the context window (params.n_ctx)
         // from llama_eval above.  Now, based on https://huggingface.co/docs/transformers/perplexity,
-        // calculate the perplexity over the last half the window (so the model always has
+        // calculate the perplexity over the last half of the window (so the model always has
         // some context to predict the token).
         //
         // We rely on the fact that attention in the forward pass only looks at previous
@@ -86,10 +96,12 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         // process the entire prompt.
         for (int j = std::min(512, params.n_ctx / 2); j < params.n_ctx - 1; ++j) {
             // Calculate probability of next token, given the previous ones.
-            std::vector<float> tok_logits(
-                logits.begin() + j * n_vocab,
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
                 logits.begin() + (j + 1) * n_vocab);
-            float prob = softmax(tok_logits)[tokens[start + j + 1]];
+
+            const float prob = softmax(tok_logits)[tokens[start + j + 1]];
+
             nll += -std::log(prob);
             ++count;
         }
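Note on the batching in the first hunk: each chunk of params.n_ctx tokens is evaluated in num_batches pieces of at most n_batch tokens via ceiling division, and only the first batch of each chunk (j == 0) has its leading token temporarily replaced by BOS, with token_org saving and restoring the original value around llama_eval. A minimal standalone sketch of that arithmetic, using hypothetical values for n_ctx and n_batch (the real program takes both from gpt_params):

    #include <algorithm>
    #include <cstdio>

    int main() {
        // hypothetical values for illustration; llama.cpp reads these from gpt_params
        const int n_ctx   = 512;
        const int n_batch = 128;

        // same ceiling division as in the patch: (512 + 127) / 128 = 4 batches per chunk
        const int num_batches = (n_ctx + n_batch - 1) / n_batch;

        for (int j = 0; j < num_batches; ++j) {
            const int batch_start = j * n_batch; // offset within the chunk
            const int batch_size  = std::min(n_ctx - batch_start, n_batch);
            printf("batch %d: start = %d, size = %d%s\n", j, batch_start, batch_size,
                   j == 0 ? " (leading token set to BOS)" : "");
        }
        return 0;
    }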
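In the scoring loop of the last hunk, softmax(tok_logits) converts one token's row of logits into probabilities, and indexing it with tokens[start + j + 1] picks out the probability the model assigned to the actual next token. Accumulating -log(prob) into nll and dividing by count gives the average negative log-likelihood per scored token, whose exponential, std::exp(nll / count), is the perplexity the example reports (the print statement is outside this hunk). The softmax helper itself is defined elsewhere in perplexity.cpp and is not part of this patch; a numerically stable sketch of what such a helper computes, illustrative only:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // sketch of a softmax over one token's logits; an assumption, not the patch's code
    static std::vector<float> softmax(const std::vector<float> & logits) {
        std::vector<float> probs(logits.size());

        // subtract the maximum logit so std::exp() cannot overflow
        const float max_logit = *std::max_element(logits.begin(), logits.end());

        double sum_exp = 0.0;
        for (size_t k = 0; k < logits.size(); ++k) {
            probs[k] = std::exp(logits[k] - max_logit);
            sum_exp += probs[k];
        }
        for (float & p : probs) {
            p /= (float) sum_exp; // normalize so the probabilities sum to 1
        }
        return probs;
    }

Subtracting the maximum logit before exponentiating leaves the result unchanged (the shift cancels in the normalization) while avoiding overflow for large logits.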