diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index e7e9aa775..4afefe1c4 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -226,7 +226,7 @@ struct kl_divergence_result { size_t count = 0; }; -static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) { +static double log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) { float max_logit = logits[0]; int imax = 0; for (int i = 1; i < n_vocab; ++i) { @@ -269,14 +269,16 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base kld.sum_kld2 += sum*sum; ++kld.count; if (imax == imax_base) ++kld.n_same_top; + return sum; } static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token, - std::vector & workers, const std::vector & base_log_probs, kl_divergence_result & kld) { + std::vector & workers, const std::vector & base_log_probs, kl_divergence_result & kld, + float * kld_values) { std::mutex mutex; const int nv = 2*((n_vocab + 1)/2) + 4; int counter = 0; - auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv] () { + auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values] () { kl_divergence_result local_kld; while (true) { std::unique_lock lock(mutex); @@ -293,7 +295,8 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens break; } lock.unlock(); - log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld); + double v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld); + kld_values[i] = (float)v; } }; for (auto & w : workers) { @@ -1628,7 +1631,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { in.read((char *)&n_vocab, sizeof(n_vocab)); in.read((char *)&n_chunk, sizeof(n_chunk)); if (in.fail()) { - fprintf(stderr, "%s: failed rwading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str()); + fprintf(stderr, "%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str()); return; } if (n_vocab != llama_n_vocab(llama_get_model(ctx))) { @@ -1647,6 +1650,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); std::vector log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); + std::vector kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); std::vector logits; if (num_batches > 1) { logits.reserve(n_ctx * n_vocab); @@ -1665,6 +1669,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { }; kl_divergence_result kld; + auto kld_ptr = kld_values.data(); for (int i = 0; i < n_chunk; ++i) { const int start = i * n_ctx; @@ -1724,7 +1729,8 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { const int first = n_ctx/2; const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, - workers, log_probs_uint16, kld); + workers, log_probs_uint16, kld, kld_ptr); + kld_ptr += n_ctx - 1 - first; auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count); @@ -1742,6 +1748,25 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { } printf("\n"); + if (kld.count < 100) return; // we do not wish to do statistics on so few values + + std::sort(kld_values.begin(), kld_values.end()); + + printf("===== KL-divergence statistics\n"); + auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); + printf("Average: %10.6f ±%10.6lf\n", kl_div.first, kl_div.second); + auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1]) + : kld_values[kld_values.size()/2]; + printf("Median : %10.6f\n", kld_median); + printf("Minimum: %10.6f\n", kld_values.front()); + printf("Maximum: %10.6f\n", kld_values.back()); + const int n_1percent = nearest_int(0.01f*kld_values.size()); + printf("KLD_01 : %10.6f\n", kld_values[n_1percent]); + printf("KLD_99 : %10.6f\n", kld_values[kld_values.size()-1-n_1percent]); + const int n_5percent = nearest_int(0.05f*kld_values.size()); + printf("KLD_05 : %10.6f\n", kld_values[n_5percent]); + printf("KLD_95 : %10.6f\n", kld_values[kld_values.size()-1-n_5percent]); + } int main(int argc, char ** argv) {