examples : replace llama_kv_cache_seq_* with llama_past_seq_*
This commit is contained in:
parent
372482dffe
commit
43d8d4bf9e
23 changed files with 125 additions and 112 deletions
|
@ -299,6 +299,10 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
n_matching_session_tokens++;
|
||||
}
|
||||
|
||||
// remove any "future" tokens that we might have inherited from the previous session
|
||||
n_matching_session_tokens = llama_past_seq_rm(ctx, -1, n_matching_session_tokens, -1);
|
||||
|
||||
if (params.prompt.empty() && n_matching_session_tokens == embd_inp.size()) {
|
||||
LOG_TEE("%s: using full prompt from session file\n", __func__);
|
||||
} else if (n_matching_session_tokens >= embd_inp.size()) {
|
||||
|
@ -310,9 +314,6 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("%s: session file matches %zu / %zu tokens of prompt\n",
|
||||
__func__, n_matching_session_tokens, embd_inp.size());
|
||||
}
|
||||
|
||||
// remove any "future" tokens that we might have inherited from the previous session
|
||||
llama_kv_cache_seq_rm(ctx, -1, n_matching_session_tokens, -1);
|
||||
}
|
||||
|
||||
LOGLN(
|
||||
|
@ -325,6 +326,8 @@ int main(int argc, char ** argv) {
|
|||
LOGLN("recalculate the cached logits (do): session_tokens.resize( %zu )", embd_inp.size() - 1);
|
||||
|
||||
session_tokens.resize(embd_inp.size() - 1);
|
||||
} else {
|
||||
session_tokens.resize(n_matching_session_tokens);
|
||||
}
|
||||
|
||||
// number of tokens to keep when resetting context
|
||||
|
@ -535,8 +538,8 @@ int main(int argc, char ** argv) {
|
|||
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
||||
llama_past_seq_rm (ctx, 0, params.n_keep , params.n_keep + n_discard);
|
||||
llama_past_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
|
||||
|
||||
n_past -= n_discard;
|
||||
|
||||
|
@ -563,9 +566,9 @@ int main(int argc, char ** argv) {
|
|||
LOG("div: [%6d, %6d] / %6d -> [%6d, %6d]\n", ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n, (ga_i + ib*bd)/ga_n, (ga_i + ib*bd + ga_w)/ga_n);
|
||||
LOG("shift: [%6d, %6d] + %6d -> [%6d, %6d]\n", ga_i + ib*bd + ga_w, n_past + ib*bd, dd, ga_i + ib*bd + ga_w + dd, n_past + ib*bd + dd);
|
||||
|
||||
llama_kv_cache_seq_add(ctx, 0, ga_i, n_past, ib*bd);
|
||||
llama_kv_cache_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
|
||||
llama_kv_cache_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
|
||||
llama_past_seq_add(ctx, 0, ga_i, n_past, ib*bd);
|
||||
llama_past_seq_div(ctx, 0, ga_i + ib*bd, ga_i + ib*bd + ga_w, ga_n);
|
||||
llama_past_seq_add(ctx, 0, ga_i + ib*bd + ga_w, n_past + ib*bd, dd);
|
||||
|
||||
n_past -= bd;
|
||||
|
||||
|
@ -579,6 +582,8 @@ int main(int argc, char ** argv) {
|
|||
if (n_session_consumed < (int) session_tokens.size()) {
|
||||
size_t i = 0;
|
||||
for ( ; i < embd.size(); i++) {
|
||||
// TODO: are the session tokens guaranteed to all be matching here?
|
||||
// Should n_matching_session_tokens be re-used instead?
|
||||
if (embd[i] != session_tokens[n_session_consumed]) {
|
||||
session_tokens.resize(n_session_consumed);
|
||||
break;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue