llama : update llama_kv_self API
ggml-ci
This commit is contained in:
parent
fd05ab87aa
commit
17b363afd3
30 changed files with 387 additions and 205 deletions
|
@ -952,9 +952,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
return iparams;
|
||||
}
|
||||
|
||||
llama_kv_cache * kv = llama_get_kv_cache(lctx);
|
||||
|
||||
if (params.ctx_shift && !llama_kv_cache_can_shift(kv)) {
|
||||
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
|
||||
LOG_WRN("%s: KV cache shifting is not supported for this model, disabling KV cache shifting\n", __func__);
|
||||
params.ctx_shift = false;
|
||||
}
|
||||
|
@ -1059,7 +1057,7 @@ struct common_init_result common_init_from_params(common_params & params) {
|
|||
if (llama_model_has_decoder(model)) {
|
||||
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
|
||||
}
|
||||
llama_kv_cache_clear(kv);
|
||||
llama_kv_self_clear(lctx);
|
||||
llama_synchronize(lctx);
|
||||
llama_perf_context_reset(lctx);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue