perplexity : add comments

This commit is contained in:
Georgi Gerganov 2024-01-18 14:20:11 +02:00
parent 0e4e58ff1b
commit 30ebd94723

View file

@ -583,9 +583,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
llama_batch_clear(batch);
// batch as much tasks as possible into the available context
// each task has 4 unique seuqnce ids
// each task has 4 unique seuqnce ids - one for each ending
// the common prefix is shared among the 4 sequences to save tokens
// we extract logits only from the last common token and all ending tokens of each sequence
// we extract logits only from the last common token and from all ending tokens of each sequence
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
auto & hs_cur = hs_data[i1];
@ -598,7 +598,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
}
batch.logits[batch.n_tokens - 1] = true;
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
for (int s = 0; s < 4; ++s) {
for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
@ -622,12 +622,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
llama_kv_cache_clear(ctx);
// decode all tasks [i0, i1)
if (llama_decode(ctx, batch) != 0) {
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
return;
}
// evaluate the computed tasks
// compute the logprobs for each ending of the decoded tasks
for (size_t i = i0; i < i1; ++i) {
auto & hs_cur = hs_data[i];