perplexity : add comments

This commit is contained in:
Georgi Gerganov 2024-01-18 14:20:11 +02:00
parent 0e4e58ff1b
commit 30ebd94723

View file

@ -583,9 +583,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
llama_batch_clear(batch); llama_batch_clear(batch);
// batch as many tasks as possible into the available context // batch as many tasks as possible into the available context
// each task has 4 unique sequence ids // each task has 4 unique sequence ids - one for each ending
// the common prefix is shared among the 4 sequences to save tokens // the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and all ending tokens of each sequence // we extract logits only from the last common token and from all ending tokens of each sequence
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) { while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
auto & hs_cur = hs_data[i1]; auto & hs_cur = hs_data[i1];
@ -598,7 +598,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
for (size_t i = 0; i < hs_cur.common_prefix; ++i) { for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
} }
batch.logits[batch.n_tokens - 1] = true; batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
for (int s = 0; s < 4; ++s) { for (int s = 0; s < 4; ++s) {
for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) { for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
@ -622,12 +622,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
llama_kv_cache_clear(ctx); llama_kv_cache_clear(ctx);
// decode all tasks [i0, i1)
if (llama_decode(ctx, batch) != 0) { if (llama_decode(ctx, batch) != 0) {
fprintf(stderr, "%s: llama_decode() failed\n", __func__); fprintf(stderr, "%s: llama_decode() failed\n", __func__);
return; return;
} }
// evaluate the computed tasks // compute the logprobs for each ending of the decoded tasks
for (size_t i = i0; i < i1; ++i) { for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i]; auto & hs_cur = hs_data[i];