llama : add struct llama_kv_cache (wip) [no ci]

This commit is contained in:
Georgi Gerganov 2025-01-13 14:13:11 +02:00
parent 178a7eb952
commit f78b396ee7
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
8 changed files with 428 additions and 415 deletions

View file

@ -952,7 +952,9 @@ struct common_init_result common_init_from_params(common_params & params) {
return iparams;
}
if (params.ctx_shift && !llama_kv_cache_can_shift(lctx)) {
llama_kv_cache * kv = llama_get_kv_cache(lctx);
if (params.ctx_shift && !llama_kv_cache_can_shift(kv)) {
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
params.ctx_shift = false;
}
@ -1057,7 +1059,7 @@ struct common_init_result common_init_from_params(common_params & params) {
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
}
llama_kv_cache_clear(lctx);
llama_kv_cache_clear(kv);
llama_synchronize(lctx);
llama_perf_context_reset(lctx);
}

View file

@ -171,8 +171,10 @@ llama_tokens common_speculative_gen_draft(
llama_tokens result;
result.reserve(params.n_draft);
llama_kv_cache * kv = llama_get_kv_cache(ctx);
if (reuse_n == 0) {
llama_kv_cache_clear(ctx);
llama_kv_cache_clear(kv);
prompt.clear();
} else {
@ -191,14 +193,14 @@ llama_tokens common_speculative_gen_draft(
}
if (reuse_i > 0) {
llama_kv_cache_seq_rm (ctx, 0, 0, reuse_i);
llama_kv_cache_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
llama_kv_cache_seq_rm (kv, 0, 0, reuse_i);
llama_kv_cache_seq_add(kv, 0, reuse_i, -1, -reuse_i);
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
}
if (reuse_n < (int) prompt.size()) {
llama_kv_cache_seq_rm (ctx, 0, reuse_n, -1);
llama_kv_cache_seq_rm (kv, 0, reuse_n, -1);
prompt.erase(prompt.begin() + reuse_n, prompt.end());
}