speculative : fix off-by-one for n_drafted
This commit is contained in:
parent
373d782d42
commit
f07cd35da4
1 changed file with 5 additions and 2 deletions
|
@@ -336,7 +336,7 @@ int main(int argc, char ** argv) {
|
|||
llama_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true);
|
||||
|
||||
// no need to evaluate the last drafted token, since we won't use the result
|
||||
if (batch_tgt.n_tokens == n_draft) {
|
||||
if (batch_tgt.n_tokens > n_draft) {
|
||||
drafts[s].drafting = false;
|
||||
continue;
|
||||
}
|
||||
|
@@ -358,11 +358,14 @@ int main(int argc, char ** argv) {
|
|||
++n_past_cur;
|
||||
++n_drafted;
|
||||
|
||||
if (batch_tgt.n_tokens >= n_draft) {
|
||||
if (batch_tgt.n_tokens > n_draft) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// account for the last drafted token that we didn't evaluate
|
||||
++n_drafted;
|
||||
|
||||
// evaluate the target model on the drafted tokens
|
||||
{
|
||||
llama_kv_cache_seq_keep(ctx_tgt, 0);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue