More accurate perplexity calculation - over all logits in the context window (so 512x more tokens!)

Author: Gary Linscott
Date:   2023-03-19 13:33:12 -07:00
parent e94bd9c7b9
commit 91d71fe0c1


@@ -527,7 +527,8 @@ bool llama_eval(
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
         std::vector<float> & embd_w,
-        size_t & mem_per_token) {
+        size_t & mem_per_token,
+        bool return_all_logits = false) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
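Because the new flag is defaulted to false, every existing call site keeps the old return-only-the-last-token behavior without modification; only the perplexity path below opts in. A minimal sketch of the two call forms (the n_past/embd values here are illustrative, not taken from the commit):

    // Generation path: unchanged call, logits receives n_vocab floats
    // (the prediction for the last position only).
    llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token);

    // Perplexity path: opt in to all logits; logits receives n_vocab * N
    // floats, laid out row-major as [N positions][n_vocab].
    llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token, true);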
@@ -733,9 +734,14 @@ bool llama_eval(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    if (return_all_logits) {
+        embd_w.resize(n_vocab * N);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+    } else {
+        // return result for just the last token
+        embd_w.resize(n_vocab);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    }
 
     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0)/N;
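With return_all_logits set, embd_w holds one row of n_vocab logits per input position, so the logits for position j occupy the half-open range [j * n_vocab, (j + 1) * n_vocab). A hypothetical helper that names this slicing (logits_for_token is my label, not part of the commit; the perplexity loop below does the same thing inline):

    #include <vector>

    // Hypothetical helper (not in the commit): copy out the logits of
    // position j from the flat [N x n_vocab] buffer that llama_eval
    // fills when return_all_logits == true.
    static std::vector<float> logits_for_token(const std::vector<float> & all_logits,
                                               int j, int n_vocab) {
        return std::vector<float>(all_logits.begin() + (size_t) j * n_vocab,
                                  all_logits.begin() + (size_t) (j + 1) * n_vocab);
    }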
@@ -769,6 +775,7 @@ void perplexity(const gpt_vocab &vocab, const llama_model &model, const gpt_para
     // Output: `perplexity: 13.5106 [114/114]`
     std::vector<gpt_vocab::id> tokens = ::llama_tokenize(vocab, params.prompt, true);
 
+    int count = 0;
     double nll = 0.0;
     int seq_count = tokens.size() / params.n_ctx;
     for (int i = 0; i < seq_count; ++i) {
@@ -776,15 +783,34 @@ void perplexity(const gpt_vocab &vocab, const llama_model &model, const gpt_para
         int end = start + params.n_ctx - 1;
         std::vector<gpt_vocab::id> embd(tokens.begin() + start, tokens.begin() + end);
         std::vector<float> logits;
-        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token)) {
+        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token, true)) {
             fprintf(stderr, "Failed to predict\n");
             return;
         }
-        // Calculate probability of next token, given the previous ones.
-        double prob = softmax(logits)[tokens[end]];
-        nll += -std::log(prob);
+        // We get the logits for all the tokens in the context window (params.n_ctx)
+        // from llama_eval above. Now, following https://huggingface.co/docs/transformers/perplexity,
+        // calculate the perplexity over the last half of the window (so the model always
+        // has some context to predict the token).
+        //
+        // We rely on the fact that attention in the forward pass only looks at previous
+        // tokens, so the logits returned for each token are an accurate representation
+        // of what the model would have predicted at that point.
+        //
+        // Example: with a context window of 512, we compute perplexity for each of the
+        // last 256 tokens. We then split the input into context-window-sized chunks to
+        // process the entire prompt.
+        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            int n_vocab = model.hparams.n_vocab;
+            std::vector<float> tok_logits(
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            nll += -std::log(prob);
+            ++count;
+        }
         // perplexity is e^(average negative log-likelihood)
-        printf("perplexity: %.4lf [%d/%d] \r", std::exp(nll / (i + 1)), i + 1, seq_count);
+        printf("perplexity: %.4lf [%d/%d] \r", std::exp(nll / count), i + 1, seq_count);
         fflush(stdout);
     }
     printf("\n");
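For reference, the arithmetic the new loop performs, sliced out as a self-contained program: softmax one position's logits, accumulate the negative log-likelihood of the true next token, and report exp of the mean. This is a minimal sketch, not the commit's code; the softmax below is a plain reference implementation (the commit reuses an existing softmax helper in the repo) and the logit values are made up:

    #include <algorithm>
    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Numerically stable softmax over one position's logits.
    static std::vector<double> softmax(const std::vector<float> & logits) {
        double max_l = *std::max_element(logits.begin(), logits.end());
        std::vector<double> probs(logits.size());
        double sum = 0.0;
        for (size_t i = 0; i < logits.size(); ++i) {
            probs[i] = std::exp((double) logits[i] - max_l);  // shift for stability
            sum += probs[i];
        }
        for (double & p : probs) p /= sum;
        return probs;
    }

    int main() {
        // Toy setup: n_vocab = 4, logits for 3 positions in one flat buffer,
        // the same layout llama_eval produces with return_all_logits == true.
        const int n_vocab = 4;
        std::vector<float> all_logits = {
            0.1f, 2.0f, 0.3f, 0.4f,   // position 0
            1.5f, 0.2f, 0.2f, 0.1f,   // position 1
            0.0f, 0.0f, 3.0f, 0.0f,   // position 2
        };
        std::vector<int> tokens = {1, 0, 2, 2};  // position j predicts tokens[j + 1]

        double nll = 0.0;
        int count = 0;
        for (int j = 0; j < 3; ++j) {
            std::vector<float> tok_logits(all_logits.begin() + j * n_vocab,
                                          all_logits.begin() + (j + 1) * n_vocab);
            nll += -std::log(softmax(tok_logits)[tokens[j + 1]]);
            ++count;
        }
        // perplexity is e^(average negative log-likelihood)
        printf("perplexity: %.4f\n", std::exp(nll / count));
        return 0;
    }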