fix context shifting

Xuan Son Nguyen 2024-10-11 14:36:48 +02:00
parent 7740c969d0
commit 6a9769a260
3 changed files with 4 additions and 4 deletions


@@ -376,7 +376,7 @@ int main(int argc, char ** argv) {
                 n_past, n_left, n_ctx, params.n_keep, n_discard);
             llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1            , params.n_keep + n_discard + 1);
-            llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
+            llama_kv_cache_seq_add(ctx, 0, params.n_keep + 1 + n_discard, n_past + 1, -n_discard);
             n_past -= n_discard;


@@ -582,7 +582,7 @@ int main(int argc, char ** argv) {
                 n_past, n_left, n_ctx, params.n_keep, n_discard);
             llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-            llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+            llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past + 1, -n_discard);
             n_past -= n_discard;
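
Note: the two hunks above apply the same fix to the context-shift step, extending the end bound of the shifted range from n_past to n_past + 1. A minimal sketch of that step after this change follows; the helper name shift_context is hypothetical, and it assumes a valid llama_context plus the same n_past/n_keep counters used in the example code.

#include "llama.h"

// Hypothetical helper, not part of this commit: free the oldest half of the
// non-kept tokens and slide the remaining KV-cache entries back so generation
// can continue once the context is full.
static void shift_context(llama_context * ctx, int & n_past, int n_keep) {
    const int n_left    = n_past - n_keep;
    const int n_discard = n_left / 2;

    // remove the tokens right after the kept prefix ...
    llama_kv_cache_seq_rm (ctx, 0, n_keep            , n_keep + n_discard);
    // ... and shift everything behind them back by n_discard positions
    // (the end bound n_past + 1 mirrors the change in the hunks above)
    llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_past + 1, -n_discard);

    n_past -= n_discard;
}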


@@ -21134,7 +21134,7 @@ int32_t llama_encode(
         struct llama_batch   batch) {
     llama_batch_allocr batch_allocr(ctx, batch);
     const int ret = llama_encode_internal(*ctx, batch_allocr.batch);
-    if (ret < 0) {
+    if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret);
     }
@@ -21146,7 +21146,7 @@ int32_t llama_decode(
         struct llama_batch   batch) {
     llama_batch_allocr batch_allocr(ctx, batch);
     const int ret = llama_decode_internal(*ctx, batch_allocr.batch);
-    if (ret < 0) {
+    if (ret != 0) {
         LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret);
     }
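
Note: the last two hunks change when llama_encode/llama_decode log a failure, from negative return values only to any non-zero value. A hedged sketch of caller-side handling follows; the helper name try_decode is hypothetical, and the return-value meanings (0 = success, 1 = no free KV-cache slot, negative = error) follow the comments in llama.h.

#include <cstdio>
#include "llama.h"

// Hypothetical caller, not part of this commit: react to llama_decode's
// return value instead of treating every non-zero result as fatal.
static bool try_decode(llama_context * ctx, llama_batch batch) {
    const int ret = llama_decode(ctx, batch);
    if (ret == 0) {
        return true;   // success: logits and KV cache were updated
    }
    if (ret > 0) {
        // e.g. 1 = no KV-cache slot for this batch; recoverable by shifting
        // the context (as in the first two hunks) or shrinking the batch
        return false;
    }
    fprintf(stderr, "llama_decode failed, ret = %d\n", ret);   // hard error
    return false;
}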