perplexity : add comments
This commit is contained in:
parent
0e4e58ff1b
commit
30ebd94723
1 changed files with 5 additions and 4 deletions
|
@ -583,9 +583,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
|
||||||
// batch as many tasks as possible into the available context
|
// batch as many tasks as possible into the available context
|
||||||
// each task has 4 unique sequence ids
|
// each task has 4 unique sequence ids - one for each ending
|
||||||
// the common prefix is shared among the 4 sequences to save tokens
|
// the common prefix is shared among the 4 sequences to save tokens
|
||||||
// we extract logits only from the last common token and all ending tokens of each sequence
|
// we extract logits only from the last common token and from all ending tokens of each sequence
|
||||||
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
|
while (n_cur + (int) hs_data[i1].required_tokens <= n_ctx) {
|
||||||
auto & hs_cur = hs_data[i1];
|
auto & hs_cur = hs_data[i1];
|
||||||
|
|
||||||
|
@ -598,7 +598,7 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||||
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
|
for (size_t i = 0; i < hs_cur.common_prefix; ++i) {
|
||||||
llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
|
llama_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false);
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix
|
||||||
|
|
||||||
for (int s = 0; s < 4; ++s) {
|
for (int s = 0; s < 4; ++s) {
|
||||||
for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
|
for (size_t i = hs_cur.common_prefix; i < hs_cur.seq_tokens[s].size(); ++i) {
|
||||||
|
@ -622,12 +622,13 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
||||||
|
|
||||||
llama_kv_cache_clear(ctx);
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
|
// decode all tasks [i0, i1)
|
||||||
if (llama_decode(ctx, batch) != 0) {
|
if (llama_decode(ctx, batch) != 0) {
|
||||||
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
|
fprintf(stderr, "%s: llama_decode() failed\n", __func__);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// evaluate the computed tasks
|
// compute the logprobs for each ending of the decoded tasks
|
||||||
for (size_t i = i0; i < i1; ++i) {
|
for (size_t i = i0; i < i1; ++i) {
|
||||||
auto & hs_cur = hs_data[i];
|
auto & hs_cur = hs_data[i];
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue