perplexity: add additional KL-divergence statistics
This commit is contained in:
parent
8128fa0bd3
commit
150af7ecf7
1 changed files with 31 additions and 6 deletions
|
@ -226,7 +226,7 @@ struct kl_divergence_result {
|
||||||
size_t count = 0;
|
size_t count = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
|
static double log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
|
||||||
float max_logit = logits[0];
|
float max_logit = logits[0];
|
||||||
int imax = 0;
|
int imax = 0;
|
||||||
for (int i = 1; i < n_vocab; ++i) {
|
for (int i = 1; i < n_vocab; ++i) {
|
||||||
|
@ -269,14 +269,16 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
|
||||||
kld.sum_kld2 += sum*sum;
|
kld.sum_kld2 += sum*sum;
|
||||||
++kld.count;
|
++kld.count;
|
||||||
if (imax == imax_base) ++kld.n_same_top;
|
if (imax == imax_base) ++kld.n_same_top;
|
||||||
|
return sum;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
|
static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
|
||||||
std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld) {
|
std::vector<std::thread> & workers, const std::vector<uint16_t> & base_log_probs, kl_divergence_result & kld,
|
||||||
|
float * kld_values) {
|
||||||
std::mutex mutex;
|
std::mutex mutex;
|
||||||
const int nv = 2*((n_vocab + 1)/2) + 4;
|
const int nv = 2*((n_vocab + 1)/2) + 4;
|
||||||
int counter = 0;
|
int counter = 0;
|
||||||
auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv] () {
|
auto compute = [&mutex, &counter, &base_log_probs, &kld, n_vocab, logits, tokens, n_token, nv, kld_values] () {
|
||||||
kl_divergence_result local_kld;
|
kl_divergence_result local_kld;
|
||||||
while (true) {
|
while (true) {
|
||||||
std::unique_lock<std::mutex> lock(mutex);
|
std::unique_lock<std::mutex> lock(mutex);
|
||||||
|
@ -293,7 +295,8 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
lock.unlock();
|
lock.unlock();
|
||||||
log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
|
double v = log_softmax(n_vocab, logits + i*n_vocab, base_log_probs.data() + i*nv, tokens[i+1], local_kld);
|
||||||
|
kld_values[i] = (float)v;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
for (auto & w : workers) {
|
for (auto & w : workers) {
|
||||||
|
@ -1628,7 +1631,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
||||||
in.read((char *)&n_vocab, sizeof(n_vocab));
|
in.read((char *)&n_vocab, sizeof(n_vocab));
|
||||||
in.read((char *)&n_chunk, sizeof(n_chunk));
|
in.read((char *)&n_chunk, sizeof(n_chunk));
|
||||||
if (in.fail()) {
|
if (in.fail()) {
|
||||||
fprintf(stderr, "%s: failed rwading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
|
fprintf(stderr, "%s: failed reading n_vocab, n_chunk from %s\n", __func__, params.logits_file.c_str());
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
|
if (n_vocab != llama_n_vocab(llama_get_model(ctx))) {
|
||||||
|
@ -1647,6 +1650,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
||||||
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx));
|
||||||
|
|
||||||
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
|
std::vector<uint16_t> log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv);
|
||||||
|
std::vector<float> kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk);
|
||||||
std::vector<float> logits;
|
std::vector<float> logits;
|
||||||
if (num_batches > 1) {
|
if (num_batches > 1) {
|
||||||
logits.reserve(n_ctx * n_vocab);
|
logits.reserve(n_ctx * n_vocab);
|
||||||
|
@ -1665,6 +1669,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
||||||
};
|
};
|
||||||
|
|
||||||
kl_divergence_result kld;
|
kl_divergence_result kld;
|
||||||
|
auto kld_ptr = kld_values.data();
|
||||||
|
|
||||||
for (int i = 0; i < n_chunk; ++i) {
|
for (int i = 0; i < n_chunk; ++i) {
|
||||||
const int start = i * n_ctx;
|
const int start = i * n_ctx;
|
||||||
|
@ -1724,7 +1729,8 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
||||||
const int first = n_ctx/2;
|
const int first = n_ctx/2;
|
||||||
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx);
|
||||||
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first,
|
||||||
workers, log_probs_uint16, kld);
|
workers, log_probs_uint16, kld, kld_ptr);
|
||||||
|
kld_ptr += n_ctx - 1 - first;
|
||||||
|
|
||||||
auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
||||||
auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
|
auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
|
||||||
|
@ -1742,6 +1748,25 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
||||||
}
|
}
|
||||||
printf("\n");
|
printf("\n");
|
||||||
|
|
||||||
|
if (kld.count < 100) return; // we do not wish to do statistics on so few values
|
||||||
|
|
||||||
|
std::sort(kld_values.begin(), kld_values.end());
|
||||||
|
|
||||||
|
printf("===== KL-divergence statistics\n");
|
||||||
|
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
||||||
|
printf("Average: %10.6f ±%10.6lf\n", kl_div.first, kl_div.second);
|
||||||
|
auto kld_median = kld_values.size()%2 == 0 ? 0.5f*(kld_values[kld_values.size()/2] + kld_values[kld_values.size()/2-1])
|
||||||
|
: kld_values[kld_values.size()/2];
|
||||||
|
printf("Median : %10.6f\n", kld_median);
|
||||||
|
printf("Minimum: %10.6f\n", kld_values.front());
|
||||||
|
printf("Maximum: %10.6f\n", kld_values.back());
|
||||||
|
const int n_1percent = nearest_int(0.01f*kld_values.size());
|
||||||
|
printf("KLD_01 : %10.6f\n", kld_values[n_1percent]);
|
||||||
|
printf("KLD_99 : %10.6f\n", kld_values[kld_values.size()-1-n_1percent]);
|
||||||
|
const int n_5percent = nearest_int(0.05f*kld_values.size());
|
||||||
|
printf("KLD_05 : %10.6f\n", kld_values[n_5percent]);
|
||||||
|
printf("KLD_95 : %10.6f\n", kld_values[kld_values.size()-1-n_5percent]);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue