diff --git a/main.cpp b/main.cpp
index c623b8b61..a7940d088 100644
--- a/main.cpp
+++ b/main.cpp
@@ -527,7 +527,8 @@ bool llama_eval(
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
               std::vector<float>         & embd_w,
-              size_t                     & mem_per_token) {
+              size_t                     & mem_per_token,
+              bool return_all_logits = false) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
@@ -733,9 +734,14 @@ bool llama_eval(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    if (return_all_logits) {
+        embd_w.resize(n_vocab * N);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+    } else {
+        // return result for just the last token
+        embd_w.resize(n_vocab);
+        memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    }
 
     if (mem_per_token == 0) {
         mem_per_token = ggml_used_mem(ctx0)/N;
@@ -769,6 +775,7 @@ void perplexity(const gpt_vocab &vocab, const llama_model &model, const gpt_para
     // Output: `perplexity: 13.5106 [114/114]`
     std::vector<gpt_vocab::id> tokens = ::llama_tokenize(vocab, params.prompt, true);
 
+    int count = 0;
     double nll = 0.0;
     int seq_count = tokens.size() / params.n_ctx;
     for (int i = 0; i < seq_count; ++i) {
@@ -776,15 +783,34 @@ void perplexity(const gpt_vocab &vocab, const llama_model &model, const gpt_para
         int end = start + params.n_ctx - 1;
         std::vector<gpt_vocab::id> embd(tokens.begin() + start, tokens.begin() + end);
         std::vector<float> logits;
-        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token)) {
+        if (!llama_eval(model, params.n_threads, 0, embd, logits, mem_per_token, true)) {
             fprintf(stderr, "Failed to predict\n");
             return;
         }
-        // Calculate probability of next token, given the previous ones.
-        double prob = softmax(logits)[tokens[end]];
-        nll += -std::log(prob);
+        // We get the logits for all the tokens in the context window (params.n_ctx)
+        // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity,
+        // calculate the perplexity over the last half of the window (so the model always has
+        // some context to predict the token).
+        //
+        // We rely on the fact that attention in the forward pass only looks at previous
+        // tokens here, so the logits returned for each token are an accurate representation
+        // of what the model would have predicted at that point.
+        //
+        // For example, with a context window of 512, we will compute perplexity for each of the
+        // last 256 tokens. Then we split the input up into context-window-sized chunks to
+        // process the entire prompt.
+        for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
+            // Calculate probability of next token, given the previous ones.
+            int n_vocab = model.hparams.n_vocab;
+            std::vector<float> tok_logits(
+                logits.begin() + j * n_vocab,
+                logits.begin() + (j + 1) * n_vocab);
+            double prob = softmax(tok_logits)[tokens[start + j + 1]];
+            nll += -std::log(prob);
+            ++count;
+        }
         // perplexity is e^(average negative log-likelihood)
-        printf("perplexity: %.4lf [%d/%d] \r", std::exp(nll / (i + 1)), i + 1, seq_count);
+        printf("perplexity: %.4lf [%d/%d] \r", std::exp(nll / count), i + 1, seq_count);
         fflush(stdout);
     }
     printf("\n");
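
To make the arithmetic behind the patch concrete, below is a small standalone C++ sketch of the same perplexity formula the new inner loop applies: softmax one position's logits, look up the probability assigned to the token that actually follows, accumulate its negative log-likelihood, and report e^(average negative log-likelihood). The toy logits, the softmax helper, and main here are illustrative stand-ins only, not part of the patch; the real code reuses main.cpp's existing softmax and the per-token logits returned by llama_eval.

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <vector>

// Numerically stable softmax over one position's logits (subtract the max first).
static std::vector<double> softmax(const std::vector<double> & logits) {
    const double max_logit = *std::max_element(logits.begin(), logits.end());

    std::vector<double> probs(logits.size());
    double sum = 0.0;
    for (size_t i = 0; i < logits.size(); ++i) {
        probs[i] = std::exp(logits[i] - max_logit);
        sum += probs[i];
    }
    for (double & p : probs) {
        p /= sum;
    }
    return probs;
}

int main() {
    // Toy data: a 4-token "vocabulary" and three prediction steps. Each row of
    // step_logits stands in for the logits returned at position j, and
    // next_token[j] is the token that actually appears at position j + 1.
    const std::vector<std::vector<double>> step_logits = {
        {2.0, 0.5, -1.0,  0.0},
        {0.1, 3.0,  0.2, -0.5},
        {1.0, 1.0,  1.0,  4.0},
    };
    const std::vector<int> next_token = {0, 1, 3};

    double nll   = 0.0;
    int    count = 0;
    for (size_t j = 0; j < step_logits.size(); ++j) {
        // Probability the model assigned to the token that actually came next.
        const double prob = softmax(step_logits[j])[next_token[j]];
        nll += -std::log(prob);
        ++count;
    }

    // Perplexity is e^(average negative log-likelihood).
    std::printf("perplexity: %.4f over %d tokens\n", std::exp(nll / count), count);
    return 0;
}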