examples : replace llama_kv_cache_seq_* with llama_past_seq_*

This commit is contained in:
Francis Couture-Harpin 2024-06-10 14:44:42 -04:00
parent 372482dffe
commit 43d8d4bf9e
23 changed files with 125 additions and 112 deletions

View file

@ -1107,7 +1107,7 @@ struct server_context {
LOG_VERBOSE("clearing KV cache", {});
// clear the entire KV cache
llama_kv_cache_clear(ctx);
llama_past_clear(ctx);
clean_kv_cache = false;
}
@ -1151,7 +1151,7 @@ struct server_context {
// assign the system KV cache to all parallel sequences
for (int32_t i = 1; i <= params.n_parallel; ++i) {
llama_kv_cache_seq_cp(ctx, 0, i, -1, -1);
llama_past_seq_cp(ctx, 0, i, -1, -1);
}
}
@ -1824,7 +1824,7 @@ struct server_context {
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
llama_past_seq_rm(ctx, slot->id + 1, -1, -1);
slot->cache_tokens.clear();
server_task_result result;
@ -1939,8 +1939,8 @@ struct server_context {
{"n_cache_tokens", slot.cache_tokens.size()}
});
llama_kv_cache_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_kv_cache_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
llama_past_seq_rm (ctx, slot.id + 1, n_keep , n_keep + n_discard);
llama_past_seq_add(ctx, slot.id + 1, n_keep + n_discard, system_tokens.size() + slot.n_past, -n_discard);
if (slot.params.cache_prompt) {
for (size_t i = n_keep + n_discard; i < slot.cache_tokens.size(); i++) {
@ -2155,23 +2155,28 @@ struct server_context {
}
// keep only the common part
int p0 = (int) system_tokens.size() + slot.n_past;
if (!llama_kv_cache_seq_rm(ctx, slot.id + 1, p0, -1)) {
// could not partially delete (likely using a non-Transformer model)
llama_kv_cache_seq_rm(ctx, slot.id + 1, -1, -1);
llama_pos p0 = (llama_pos) system_tokens.size() + slot.n_past;
p0 = (int) system_tokens.size();
if (p0 != 0) {
// copy over the system prompt when there is one
llama_kv_cache_seq_cp(ctx, 0, slot.id + 1, -1, -1);
// for recurrent and hybrid models, sometimes it goes back further than asked
llama_pos new_p0 = llama_past_seq_rm(ctx, slot.id + 1, p0, -1);
if (new_p0 < p0) {
GGML_ASSERT(new_p0 >= (llama_pos) system_tokens.size());
slot.n_past -= p0 - new_p0;
if (slot.ga_i > 0) {
// TODO: test with an hybrid model (e.g. Jamba)
slot.n_past_se -= p0 - new_p0;
}
// there is no common part left (except for the system prompt)
slot.n_past = 0;
slot.n_past_se = 0;
slot.ga_i = 0;
// TODO: is the system prompt ever in the sampling context?
// TODO: find a way to avoid rolling back the sampling context twice
llama_sampling_reset(slot.ctx_sampling);
// push the prompt into the sampling context (do not apply grammar)
for (int i = 0; i < slot.n_past; ++i) {
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
}
p0 = new_p0;
}
// remove the non-common part from the cache
@ -2273,9 +2278,9 @@ struct server_context {
LOG_TEE("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n, (slot.ga_i + ib * bd) / slot.ga_n, (slot.ga_i + ib * bd + slot.ga_w) / slot.ga_n);
LOG_TEE("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd, slot.ga_i + ib * bd + slot.ga_w + dd, slot.n_past_se + ib * bd + dd);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_kv_cache_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_kv_cache_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
llama_past_seq_add(ctx, slot.id + 1, slot.ga_i, slot.n_past_se, ib * bd);
llama_past_seq_div(ctx, slot.id + 1, slot.ga_i + ib * bd, slot.ga_i + ib * bd + slot.ga_w, slot.ga_n);
llama_past_seq_add(ctx, slot.id + 1, slot.ga_i + ib * bd + slot.ga_w, slot.n_past_se + ib * bd, dd);
slot.n_past_se -= bd;