update from merge

This commit is contained in:
Gary Linscott 2023-03-25 13:30:40 -07:00
parent c3d3cd2d45
commit 7392ad629d

View file

@ -26,17 +26,26 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
int count = 0;
double nll = 0.0;
int seq_count = tokens.size() / params.n_ctx;
int n_vocab = llama_n_vocab(ctx);
fprintf(stderr, "%s : calculating perplexity over %d chunks\n", __func__, seq_count);
fprintf(stderr, "%s : calculating perplexity over %d chunks, batch_size=%d\n", __func__, seq_count, params.n_batch);
for (int i = 0; i < seq_count; ++i) {
int start = i * params.n_ctx;
int end = start + params.n_ctx - 1;
std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
std::vector<float> logits;
int num_batches = (params.n_ctx + params.n_batch - 1) / params.n_batch;
auto start_t = std::chrono::high_resolution_clock::now();
if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
for (int j = 0; j < num_batches; ++j) {
int batch_start = start + j * params.n_batch;
int batch_size = std::min(end - batch_start, params.n_batch);
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * params.n_batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return;
}
auto batch_logits = llama_get_logits(ctx);
logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab);
}
auto end_t = std::chrono::high_resolution_clock::now();
if (i == 0) {
@ -56,13 +65,11 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
// last 256 tokens. Then, we split the input up into context window size chunks to
// process the entire prompt.
auto logits = llama_get_logits(ctx);
for (int j = params.n_ctx / 2; j < params.n_ctx - 1; ++j) {
// Calculate probability of next token, given the previous ones.
int n_vocab = llama_n_vocab(ctx);
std::vector<float> tok_logits(
logits + j * n_vocab,
logits + (j + 1) * n_vocab);
logits.begin() + j * n_vocab,
logits.begin() + (j + 1) * n_vocab);
double prob = softmax(tok_logits)[tokens[start + j + 1]];
nll += -std::log(prob);
++count;