From 8128fa0bd3142e1e9f77970e977100977b898a08 Mon Sep 17 00:00:00 2001 From: Iwan Kawrakow Date: Mon, 22 Jan 2024 17:11:57 +0200 Subject: [PATCH] perplexity: add top-token probability --- examples/perplexity/perplexity.cpp | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 1b7f85f49..e7e9aa775 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -222,13 +222,18 @@ struct kl_divergence_result { double sum_kld2 = 0; double sum_nll_diff = 0; double sum_nll_diff2 = 0; + size_t n_same_top = 0; size_t count = 0; }; static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) { float max_logit = logits[0]; + int imax = 0; for (int i = 1; i < n_vocab; ++i) { - max_logit = std::max(max_logit, logits[i]); + if (logits[i] > max_logit) { + max_logit = logits[i]; + imax = i; + } } double sum_exp = 0.0; for (int i = 0; i < n_vocab; ++i) { @@ -247,8 +252,14 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base kld.sum_nll_diff2 += nll*nll; max_logit += log_sum_exp; double sum = 0; + int imax_base = -1; + float p_log_base_max = 0; for (int i = 0; i < n_vocab; ++i) { const float p_log_base = scale*base_log_prob[i] + min_log_prob; + if (i == 0 || p_log_base > p_log_base_max) { + p_log_base_max = p_log_base; + imax_base = i; + } if (p_log_base > -16.f) { const float p_base = expf(p_log_base); sum += p_base * (p_log_base - logits[i] + max_logit); @@ -257,6 +268,7 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base kld.sum_kld += sum; kld.sum_kld2 += sum*sum; ++kld.count; + if (imax == imax_base) ++kld.n_same_top; } static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, @@ -276,6 +288,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens kld.sum_kld2 += local_kld.sum_kld2; kld.sum_nll_diff += local_kld.sum_nll_diff; kld.sum_nll_diff2 += local_kld.sum_nll_diff2; + kld.n_same_top += local_kld.n_same_top; kld.count += local_kld.count; break; } @@ -1705,7 +1718,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { } fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); - printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence\n"); + printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence Same top\n"); } const int first = n_ctx/2; @@ -1716,9 +1729,12 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count); auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); + auto p_top = 1.*kld.n_same_top/kld.count; + auto d_p_top = sqrt(p_top*(1 - p_top)/(kld.count - 1)); - printf("%4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf\n", i+1, exp(ppl.first), - log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second); + printf("%4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf %.5f ± %.5f\n", i+1, exp(ppl.first), + log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second, + p_top, d_p_top); fflush(stdout);