From 545862ae48ed30060370dfdfe4d751c927d1dc44 Mon Sep 17 00:00:00 2001
From: klosax <131523366+klosax@users.noreply.github.com>
Date: Fri, 21 Jul 2023 21:25:44 +0200
Subject: [PATCH] Update perplexity.cpp

---
 examples/perplexity/perplexity.cpp | 85 +++++++++++++++++++++++++++++-
 1 file changed, 83 insertions(+), 2 deletions(-)

diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index bfad99939..9c58683f5 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -120,6 +120,83 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
     printf("\n");
 }
 
+void perplexity_lines(llama_context * ctx, const gpt_params & params) {
+    // Calculates perplexity over each line of the prompt
+
+    std::vector<std::string> prompt_lines;
+
+    size_t pos = 0;
+    while (pos < params.prompt.size()) {
+        std::string line;
+        while (true) {
+            if (pos == params.prompt.size() || params.prompt[pos] == '\n')
+                break;
+            line += params.prompt[pos++];
+        }
+        pos++;
+        prompt_lines.push_back(line);
+    }
+
+    const int n_vocab = llama_n_vocab(ctx);
+
+    int counttotal = 0;
+    size_t n_lines = prompt_lines.size();
+
+    double nll = 0.0;
+
+    fprintf(stderr, "%s: calculating perplexity over %zu lines\n", __func__, n_lines);
+
+    printf("\nLine\tPPL line\tPPL cumulative\n");
+
+    for (size_t i = 0; i < n_lines; ++i) {
+
+        // Tokenize and insert BOS at start
+        std::vector<int> batch_embd = ::llama_tokenize(ctx, prompt_lines[i], true);
+
+        size_t batch_size = batch_embd.size();
+
+        // Stop if line is too long
+        if (batch_size > (size_t)params.n_ctx) {
+            fprintf(stderr, "%s : tokens in line %zu > n_ctx\n", __func__, i);
+            return;
+        }
+
+        if (llama_eval(ctx, batch_embd.data(), batch_size, 0, params.n_threads)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return;
+        }
+
+        const auto batch_logits = llama_get_logits(ctx);
+        std::vector<float> logits;
+        logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
+
+        double nllline = 0.0;
+        int countline = 0;
+
+        // Perplexity over the second half of the line
+        for (size_t j = batch_size / 2; j < batch_size - 1; ++j) {
+            // Calculate probability of the next token, given the previous ones.
+            const std::vector<float> tok_logits(
+                logits.begin() + (j + 0) * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+
+            const float prob = softmax(tok_logits)[batch_embd[j + 1]];
+
+            nllline += -std::log(prob);
+            ++countline;
+        }
+
+        nll += nllline;
+        counttotal += countline;
+
+        // perplexity is e^(average negative log-likelihood)
+        printf("%zu\t%.8lf\t%.8lf\n", i + 1, std::exp(nllline / countline), std::exp(nll / counttotal));
+        fflush(stdout);
+    }
+
+    printf("\n");
+}
+
 int main(int argc, char ** argv) {
     gpt_params params;
 
@@ -168,8 +245,12 @@ int main(int argc, char ** argv) {
                params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
     }
 
-    perplexity(ctx, params);
-
+    if (params.perplexity_lines) {
+        perplexity_lines(ctx, params);
+    } else {
+        perplexity(ctx, params);
+    }
+
     llama_print_timings(ctx);
     llama_free(ctx);
     llama_free_model(model);
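
Note (not part of the patch): below is a minimal standalone sketch of the math that perplexity_lines relies on, namely perplexity = exp(average negative log-likelihood), with each per-token probability taken from a softmax over that step's logits. The names softmax_sketch, toy_logits and target are hypothetical and exist only for this illustration; the patch itself uses the softmax() helper already defined in perplexity.cpp and real logits from llama_eval().

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over a vector of logits (illustrative only).
static std::vector<float> softmax_sketch(const std::vector<float> & logits) {
    float max_logit = logits[0];
    for (float v : logits) {
        max_logit = std::max(max_logit, v);
    }
    double sum = 0.0;
    std::vector<float> probs(logits.size());
    for (size_t k = 0; k < logits.size(); ++k) {
        probs[k] = std::exp(logits[k] - max_logit);
        sum += probs[k];
    }
    for (float & p : probs) {
        p = (float)(p / sum);
    }
    return probs;
}

int main() {
    // Toy data: logits for 3 prediction steps over a 4-token vocabulary,
    // plus the token that actually followed at each step.
    const std::vector<std::vector<float>> toy_logits = {
        {2.0f, 0.5f, 0.1f, -1.0f},
        {0.3f, 1.7f, 0.2f,  0.0f},
        {0.1f, 0.1f, 2.5f,  0.4f},
    };
    const std::vector<int> target = {0, 1, 2};

    double nll = 0.0;
    int count = 0;
    for (size_t j = 0; j < toy_logits.size(); ++j) {
        // Negative log-likelihood of the observed next token at step j.
        const float prob = softmax_sketch(toy_logits[j])[target[j]];
        nll += -std::log(prob);
        ++count;
    }

    // Same formula as the patch: perplexity = e^(average negative log-likelihood).
    printf("perplexity = %.6f\n", std::exp(nll / count));
    return 0;
}

In the patch, the same accumulation is done per line, but only over the second half of each line's tokens, so the model always has some preceding context before its predictions are scored.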