examples : replace llama_kv_cache_seq_* with llama_past_seq_*

parent 372482dffe
commit 43d8d4bf9e

23 changed files with 125 additions and 112 deletions
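The hunks below show the pattern followed across the changed files: most call sites are a one-for-one rename with the arguments left untouched (llama_kv_cache_clear becomes llama_past_clear, llama_kv_cache_seq_cp becomes llama_past_seq_cp, and so on), while sequence removal also changes what it returns. llama_kv_cache_seq_rm reports success as a bool, whereas llama_past_seq_rm, as used in this commit, returns the llama_pos it actually removed from. A minimal sketch of the two call styles, assuming the llama_past_* declarations from this branch (they are not part of a released llama.h):

#include "llama.h"

// Old API: a partial removal can fail outright (e.g. for recurrent models),
// in which case the caller falls back to wiping the whole sequence.
static void trim_seq_old(llama_context * ctx, llama_seq_id seq, llama_pos p0) {
    if (!llama_kv_cache_seq_rm(ctx, seq, p0, -1)) {
        llama_kv_cache_seq_rm(ctx, seq, -1, -1); // remove everything instead
    }
}

// New API used by this commit (signature inferred from the hunks below):
// llama_past_seq_rm returns the position it actually removed from,
// which may be earlier than the requested p0.
static llama_pos trim_seq_new(llama_context * ctx, llama_seq_id seq, llama_pos p0) {
    return llama_past_seq_rm(ctx, seq, p0, -1);
}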
@@ -1107,7 +1107,7 @@ struct server_context {
         LOG_VERBOSE("clearing KV cache", {});
 
         // clear the entire KV cache
-        llama_kv_cache_clear(ctx);
+        llama_past_clear(ctx);
         clean_kv_cache = false;
     }
 
@@ -1151,7 +1151,7 @@ struct server_context {
 
         // assign the system KV cache to all parallel sequences
         for (int32_t i = 1; i <= params.n_parallel; ++i) {
-            llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
+            llama_past_seq_cp(ctx, 0, i, -1, -1);
         }
     }
 
@@ -1824,7 +1824,7 @@ struct server_context {
 
                 // Erase token cache
                 const size_t n_erased = slot->cache_tokens.size();
-                llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
+                llama_past_seq_rm(ctx, slot->id + 1, -1, -1);
                 slot->cache_tokens.clear();
 
                 server_task_result result;
@@ -1939,8 +1939,8 @@ struct server_context {
                     {"n_cache_tokens", slot.cache_tokens.size()}
                 });
 
-                llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
+                llama_past_seq_rm (ctx, slot.id + 1, n_keep            , n_keep + n_discard);
+                llama_past_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
 
                 if (slot.params.cache_prompt) {
                     for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@@ -2155,23 +2155,28 @@ struct server_context {
                     }
 
                     // keep only the common part
-                    int p0 = (int) system_tokens.size() + slot.n_past;
-                    if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
-                        // could not partially delete (likely using a non-Transformer model)
-                        llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
+                    llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past;
 
-                        p0 = (int) system_tokens.size();
-                        if (p0 != 0) {
-                            // copy over the system prompt when there is one
-                            llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
-                        }
+                    // for recurrent and hybrid models, sometimes it goes back further than asked
+                    llama_pos new_p0 = llama_past_seq_rm(ctx, slot.id + 1, p0, -1);
+
+                    if (new_p0 < p0) {
+                        GGML_ASSERT(new_p0 >= (llama_pos) system_tokens.size());
 
-                        // there is no common part left (except for the system prompt)
-                        slot.n_past = 0;
-                        slot.n_past_se = 0;
-                        slot.ga_i = 0;
-                        // TODO: is the system prompt ever in the sampling context?
+                        slot.n_past -= p0 - new_p0;
+                        if (slot.ga_i > 0) {
+                            // TODO: test with an hybrid model (e.g. Jamba)
+                            slot.n_past_se -= p0 - new_p0;
+                        }
+
+                        // TODO: find a way to avoid rolling back the sampling context twice
                         llama_sampling_reset(slot.ctx_sampling);
+                        // push the prompt into the sampling context (do not apply grammar)
+                        for (int i = 0; i < slot.n_past; ++i) {
+                            llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
+                        }
+
+                        p0 = new_p0;
                     }
 
                     // remove the non-common part from the cache
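The hunk above is where the new return value changes control flow: a recurrent or hybrid state cannot always be truncated at an arbitrary position, so the removal may reach back further than requested, and the server then shrinks slot.n_past to match and rebuilds the sampling context from the tokens still in the cache. A condensed sketch of that caller pattern, with illustrative stand-in names (trim_to_common_part, n_system and cached_tokens are placeholders, not the server's actual fields):

#include <cstdint>
#include <vector>
#include "llama.h"
#include "sampling.h"

// Condensed form of the rollback in the hunk above; all names are illustrative.
// Returns the position from which the non-common part should be re-decoded.
static llama_pos trim_to_common_part(
        llama_context                  * ctx,
        llama_sampling_context         * ctx_sampling,
        const std::vector<llama_token> & cached_tokens,
        llama_seq_id seq_id,
        llama_pos    n_system,
        int32_t    & n_past) {
    llama_pos p0     = n_system + n_past;                      // first position we want dropped
    llama_pos new_p0 = llama_past_seq_rm(ctx, seq_id, p0, -1); // may reach back further than asked

    if (new_p0 < p0) {
        // fewer positions survived than requested, so fewer tokens count as processed
        n_past -= p0 - new_p0;

        // rebuild the sampler state from the tokens still in the cache
        // (grammar is not applied while replaying the prompt)
        llama_sampling_reset(ctx_sampling);
        for (int i = 0; i < n_past; ++i) {
            llama_sampling_accept(ctx_sampling, ctx, cached_tokens[i], false);
        }

        p0 = new_p0;
    }

    return p0;
}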
@@ -2273,9 +2278,9 @@ struct server_context {
                         LOG_TEE("div:   [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
                         LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
 
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
-                        llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
-                        llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
+                        llama_past_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
+                        llama_past_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
+                        llama_past_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
 
                         slot.n_past_se -= bd;
 