perplexity: add top-token probability
This commit is contained in:
parent
6f9939d119
commit
8128fa0bd3
1 changed files with 20 additions and 4 deletions
|
@ -222,13 +222,18 @@ struct kl_divergence_result {
|
||||||
double sum_kld2 = 0;
|
double sum_kld2 = 0;
|
||||||
double sum_nll_diff = 0;
|
double sum_nll_diff = 0;
|
||||||
double sum_nll_diff2 = 0;
|
double sum_nll_diff2 = 0;
|
||||||
|
size_t n_same_top = 0;
|
||||||
size_t count = 0;
|
size_t count = 0;
|
||||||
};
|
};
|
||||||
|
|
||||||
static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
|
static void log_softmax(int n_vocab, const float * logits, const uint16_t * base_log_prob, int tok, kl_divergence_result & kld) {
|
||||||
float max_logit = logits[0];
|
float max_logit = logits[0];
|
||||||
|
int imax = 0;
|
||||||
for (int i = 1; i < n_vocab; ++i) {
|
for (int i = 1; i < n_vocab; ++i) {
|
||||||
max_logit = std::max(max_logit, logits[i]);
|
if (logits[i] > max_logit) {
|
||||||
|
max_logit = logits[i];
|
||||||
|
imax = i;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
double sum_exp = 0.0;
|
double sum_exp = 0.0;
|
||||||
for (int i = 0; i < n_vocab; ++i) {
|
for (int i = 0; i < n_vocab; ++i) {
|
||||||
|
@ -247,8 +252,14 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
|
||||||
kld.sum_nll_diff2 += nll*nll;
|
kld.sum_nll_diff2 += nll*nll;
|
||||||
max_logit += log_sum_exp;
|
max_logit += log_sum_exp;
|
||||||
double sum = 0;
|
double sum = 0;
|
||||||
|
int imax_base = -1;
|
||||||
|
float p_log_base_max = 0;
|
||||||
for (int i = 0; i < n_vocab; ++i) {
|
for (int i = 0; i < n_vocab; ++i) {
|
||||||
const float p_log_base = scale*base_log_prob[i] + min_log_prob;
|
const float p_log_base = scale*base_log_prob[i] + min_log_prob;
|
||||||
|
if (i == 0 || p_log_base > p_log_base_max) {
|
||||||
|
p_log_base_max = p_log_base;
|
||||||
|
imax_base = i;
|
||||||
|
}
|
||||||
if (p_log_base > -16.f) {
|
if (p_log_base > -16.f) {
|
||||||
const float p_base = expf(p_log_base);
|
const float p_base = expf(p_log_base);
|
||||||
sum += p_base * (p_log_base - logits[i] + max_logit);
|
sum += p_base * (p_log_base - logits[i] + max_logit);
|
||||||
|
@ -257,6 +268,7 @@ static void log_softmax(int n_vocab, const float * logits, const uint16_t * base
|
||||||
kld.sum_kld += sum;
|
kld.sum_kld += sum;
|
||||||
kld.sum_kld2 += sum*sum;
|
kld.sum_kld2 += sum*sum;
|
||||||
++kld.count;
|
++kld.count;
|
||||||
|
if (imax == imax_base) ++kld.n_same_top;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
|
static void process_logits(int n_vocab, const float * logits, const int * tokens, int n_token,
|
||||||
|
@ -276,6 +288,7 @@ static void process_logits(int n_vocab, const float * logits, const int * tokens
|
||||||
kld.sum_kld2 += local_kld.sum_kld2;
|
kld.sum_kld2 += local_kld.sum_kld2;
|
||||||
kld.sum_nll_diff += local_kld.sum_nll_diff;
|
kld.sum_nll_diff += local_kld.sum_nll_diff;
|
||||||
kld.sum_nll_diff2 += local_kld.sum_nll_diff2;
|
kld.sum_nll_diff2 += local_kld.sum_nll_diff2;
|
||||||
|
kld.n_same_top += local_kld.n_same_top;
|
||||||
kld.count += local_kld.count;
|
kld.count += local_kld.count;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -1705,7 +1718,7 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
||||||
}
|
}
|
||||||
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0);
|
||||||
|
|
||||||
printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence\n");
|
printf("\nchunk PPL ln(PPL(Q)/PPL(base)) KL-Divergence Same top\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
const int first = n_ctx/2;
|
const int first = n_ctx/2;
|
||||||
|
@ -1716,9 +1729,12 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) {
|
||||||
auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
auto ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count);
|
||||||
auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
|
auto log_ppl_ratio = mean_and_uncertainty(kld.sum_nll_diff, kld.sum_nll_diff2, kld.count);
|
||||||
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count);
|
||||||
|
auto p_top = 1.*kld.n_same_top/kld.count;
|
||||||
|
auto d_p_top = sqrt(p_top*(1 - p_top)/(kld.count - 1));
|
||||||
|
|
||||||
printf("%4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf\n", i+1, exp(ppl.first),
|
printf("%4d %10.4lf %10.5lf ± %10.5f %10.5f ± %10.5lf %.5f ± %.5f\n", i+1, exp(ppl.first),
|
||||||
log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second);
|
log_ppl_ratio.first, log_ppl_ratio.second, kl_div.first, kl_div.second,
|
||||||
|
p_top, d_p_top);
|
||||||
|
|
||||||
fflush(stdout);
|
fflush(stdout);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue