examples : replace llama_kv_cache_seq_* with llama_past_seq_*

This commit is contained in:
Francis Couture-Harpin 2024-06-10 14:44:42 -04:00
parent 372482dffe
commit 43d8d4bf9e
23 changed files with 125 additions and 112 deletions

View file

@ -299,6 +299,10 @@ int main(int argc, char ** argv) {
}
n_matching_session_tokens++;
}
// remove any "future" tokens that we might have inherited from the previous session
n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1);
if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
LOG_TEE("%s: using full prompt from session file\n", __func__);
} else if (n_matching_session_tokens >= embd_inp.size()) {
@ -310,9 +314,6 @@ int main(int argc, char ** argv) {
LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n",
__func__, n_matching_session_tokens, embd_inp.size());
}
// remove any "future" tokens that we might have inherited from the previous session
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
}
LOGLN(
@ -325,6 +326,8 @@ int main(int argc, char ** argv) {
LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1);
session_tokens.resize(embd_inp.size() - 1);
} else {
session_tokens.resize(n_matching_session_tokens);
}
// number of tokens to keep when resetting context
@ -535,8 +538,8 @@ int main(int argc, char ** argv) {
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
n_past, n_left, n_ctx, params.n_keep, n_discard);
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
llama_past_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
llama_past_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
n_past -= n_discard;
@ -563,9 +566,9 @@ int main(int argc, char ** argv) {
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd);
llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
n_past -= bd;
@ -579,6 +582,8 @@ int main(int argc, char ** argv) {
if (n_session_consumed < (int) session_tokens.size()) {
size_t i = 0;
for ( ; i < embd.size(); i++) {
// TODO: are the session tokens guaranteed to all be matching here?
// Should n_matching_session_tokens be re-used instead?
if (embd[i] != session_tokens[n_session_consumed]) {
session_tokens.resize(n_session_consumed);
break;