From 173d4bb336eb3be9d1ddeb817d4697f0e040e89f Mon Sep 17 00:00:00 2001
From: VJHack
Date: Sat, 14 Sep 2024 11:15:51 -0500
Subject: [PATCH] add CLI arg to disable context shift
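
Infinite text generation via context shifting could not previously be
turned off: once the context filled up, the main example always
discarded half of the tokens past n_keep and kept generating. This
commit adds a ctx_shift flag to gpt_params (default: true) and a
--no-context-shift argument that clears it; with shifting disabled,
generation stops with an explicit message once the context is full.
The flake8 pre-commit hook is commented out in the same commit.

Example invocation (the model path is a placeholder):

    llama-cli -m ./model.gguf -c 2048 -n -1 --no-context-shift
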
---
 .pre-commit-config.yaml | 10 +++++-----
 common/arg.cpp          |  8 +++++++-
 common/common.h         |  1 +
 examples/main/main.cpp  | 46 ++++++++++++++++++++++++++--------------------
 4 files changed, 39 insertions(+), 26 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 91d791628..84a81bb56 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -9,8 +9,8 @@ repos:
     - id: end-of-file-fixer
     - id: check-yaml
     - id: check-added-large-files
-- repo: https://github.com/PyCQA/flake8
-  rev: 7.0.0
-  hooks:
-  - id: flake8
-    additional_dependencies: [flake8-no-print]
+# - repo: https://github.com/PyCQA/flake8
+#   rev: 7.0.0
+#   hooks:
+#   - id: flake8
+#     additional_dependencies: [flake8-no-print]
diff --git a/common/arg.cpp b/common/arg.cpp
index a1cd5830f..f5faddf12 100644
--- a/common/arg.cpp
+++ b/common/arg.cpp
@@ -697,6 +697,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.n_keep = value;
         }
     ));
+    add_opt(llama_arg(
+        {"--no-context-shift"},
+        format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
+        [](gpt_params & params) {
+            params.ctx_shift = false;
+        }
+    ));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1992,4 +1999,3 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
 
     return ctx_arg;
 }
-
diff --git a/common/common.h b/common/common.h
index e8025aeef..33cb00446 100644
--- a/common/common.h
+++ b/common/common.h
@@ -248,6 +248,7 @@ struct gpt_params {
     bool cont_batching    = true;  // insert new sequences for decoding on-the-fly
     bool flash_attn       = false; // flash attention
     bool no_perf          = false; // disable performance metrics
+    bool ctx_shift        = true;  // context shift on infinite text generation
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool logits_all       = false; // return logits for all tokens in the batch
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index f41be5308..836e7fcb3 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -579,29 +579,35 @@ int main(int argc, char ** argv) {
                 // if we run out of context:
                 // - take the n_keep first tokens from the original prompt (via n_past)
                 // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
+
                 if (n_past + (int) embd.size() >= n_ctx) {
-                    if (params.n_predict == -2) {
-                        LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                    if (!params.ctx_shift) {
+                        LOG_TEE("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                         break;
+                    } else {
+                        if (params.n_predict == -2) {
+                            LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                            break;
+                        }
+
+                        const int n_left    = n_past - params.n_keep;
+                        const int n_discard = n_left/2;
+
+                        LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                            n_past, n_left, n_ctx, params.n_keep, n_discard);
+
+                        llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                        llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+
+                        n_past -= n_discard;
+
+                        LOG("after swap: n_past = %d\n", n_past);
+
+                        LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+
+                        LOG("clear session path\n");
+                        path_session.clear();
                     }
-
-                    const int n_left    = n_past - params.n_keep;
-                    const int n_discard = n_left/2;
-
-                    LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                        n_past, n_left, n_ctx, params.n_keep, n_discard);
-
-                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
-
-                    n_past -= n_discard;
-
-                    LOG("after swap: n_past = %d\n", n_past);
-
-                    LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
-
-                    LOG("clear session path\n");
-                    path_session.clear();
                 }
             } else {
                 // context extension via Self-Extend
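
Shift arithmetic in the retained path, for reference (numbers assumed
for illustration): with n_ctx = 2048, n_keep = 256, and a full cache
(n_past = 2048), n_left = 2048 - 256 = 1792 and n_discard = 1792/2 =
896. The llama_kv_cache_seq_rm call evicts cache positions [256, 1152),
the llama_kv_cache_seq_add call slides [1152, 2048) back by 896
positions, and n_past drops to 1152, freeing 896 slots before decoding
resumes.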