llama : fix handling of "future" tokens when loading sessions
This commit is contained in:
parent
0f332a9104
commit
337120cc0d
6 changed files with 41 additions and 40 deletions
|
@ -543,6 +543,9 @@ int main(int argc, char ** argv) {
|
|||
if (i > 0) {
|
||||
embd.erase(embd.begin(), embd.begin() + i);
|
||||
}
|
||||
|
||||
// remove any "future" tokens that we might have inherited from the session from the KV cache
|
||||
llama_kv_cache_tokens_rm(ctx, n_past, -1);
|
||||
}
|
||||
|
||||
// evaluate tokens in batches
|
||||
|
|
|
@ -332,7 +332,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
|
||||
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx);
|
||||
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, -1);
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
|
|
|
@ -448,7 +448,7 @@ struct llama_server_context
|
|||
n_past = common_part(embd, prompt_tokens);
|
||||
|
||||
// since #3228 we now have to manually manage the KV cache
|
||||
llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
|
||||
llama_kv_cache_seq_rm(ctx, 0, n_past, -1);
|
||||
|
||||
embd = prompt_tokens;
|
||||
if (n_past == num_prompt_tokens)
|
||||
|
|
|
@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
|
|||
LOG("out of drafted tokens\n");
|
||||
}
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, -1);
|
||||
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0));
|
||||
++n_past_dft;
|
||||
|
||||
|
@ -257,7 +257,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// evaluate the drafted token on the draft model
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, -1);
|
||||
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0));
|
||||
++n_past_cur;
|
||||
|
||||
|
@ -267,7 +267,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// evaluate the target model on the drafted tokens
|
||||
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
|
||||
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, -1);
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0));
|
||||
++n_past_tgt;
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue