added cli arg to disable context shift

parent 822b6322de
commit 173d4bb336

4 changed files with 39 additions and 26 deletions
@@ -9,8 +9,8 @@ repos:
       - id: end-of-file-fixer
       - id: check-yaml
       - id: check-added-large-files
-  - repo: https://github.com/PyCQA/flake8
-    rev: 7.0.0
-    hooks:
-      - id: flake8
-        additional_dependencies: [flake8-no-print]
+#  - repo: https://github.com/PyCQA/flake8
+#    rev: 7.0.0
+#    hooks:
+#      - id: flake8
+#        additional_dependencies: [flake8-no-print]
@@ -697,6 +697,13 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
             params.n_keep = value;
         }
     ));
+    add_opt(llama_arg(
+        {"--no-context-shift"},
+        format("disables context shift on infinite text generation (default: %s)", params.ctx_shift ? "disabled" : "enabled"),
+        [](gpt_params & params) {
+            params.ctx_shift = false;
+        }
+    ));
     add_opt(llama_arg(
         {"--chunks"}, "N",
         format("max number of chunks to process (default: %d, -1 = all)", params.n_chunks),
@@ -1992,4 +1999,3 @@ gpt_params_context gpt_params_parser_init(gpt_params & params, llama_example ex,
 
     return ctx_arg;
 }
-
@@ -248,6 +248,7 @@ struct gpt_params {
     bool cont_batching = true; // insert new sequences for decoding on-the-fly
     bool flash_attn = false; // flash attention
     bool no_perf = false; // disable performance metrics
+    bool ctx_shift = true; // context shift on infinite text generation
 
     bool input_prefix_bos = false; // prefix BOS to user inputs, preceding input_prefix
     bool logits_all = false; // return logits for all tokens in the batch
@@ -579,29 +579,35 @@ int main(int argc, char ** argv) {
             // if we run out of context:
             // - take the n_keep first tokens from the original prompt (via n_past)
             // - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
 
             if (n_past + (int) embd.size() >= n_ctx) {
-                if (params.n_predict == -2) {
-                    LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                if (!params.ctx_shift) {
+                    LOG_TEE("\n\n%s: context full and context shift is disabled => stopping\n", __func__);
                     break;
-                }
+                } else {
+                    if (params.n_predict == -2) {
+                        LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
+                        break;
+                    }
 
-                const int n_left    = n_past - params.n_keep;
-                const int n_discard = n_left/2;
+                    const int n_left    = n_past - params.n_keep;
+                    const int n_discard = n_left/2;
 
-                LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
-                    n_past, n_left, n_ctx, params.n_keep, n_discard);
+                    LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
+                        n_past, n_left, n_ctx, params.n_keep, n_discard);
 
-                llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
-                llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
+                    llama_kv_cache_seq_rm (ctx, 0, params.n_keep            , params.n_keep + n_discard);
+                    llama_kv_cache_seq_add(ctx, 0, params.n_keep + n_discard, n_past, -n_discard);
 
-                n_past -= n_discard;
+                    n_past -= n_discard;
 
-                LOG("after swap: n_past = %d\n", n_past);
+                    LOG("after swap: n_past = %d\n", n_past);
 
-                LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
+                    LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
 
-                LOG("clear session path\n");
-                path_session.clear();
+                    LOG("clear session path\n");
+                    path_session.clear();
+                }
             }
         } else {
             // context extension via Self-Extend
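For reference, the arithmetic performed by the context-shift branch above, recreated as a self-contained sketch: it operates on a plain std::vector of token ids instead of the llama.cpp KV cache, so the llama_kv_cache_seq_rm / llama_kv_cache_seq_add pair is replaced by a single erase, and the sizes are made-up toy values.

#include <cstdio>
#include <vector>

// Sketch of the context-shift step: keep the first n_keep tokens,
// drop half of the remaining (n_past - n_keep) tokens, and slide the
// rest down so generation can continue in the freed space.
static int context_shift(std::vector<int> & cache, int n_past, int n_keep) {
    const int n_left    = n_past - n_keep;
    const int n_discard = n_left / 2;

    // Equivalent, on a plain vector, of removing [n_keep, n_keep + n_discard)
    // and shifting the tail left by n_discard positions.
    cache.erase(cache.begin() + n_keep, cache.begin() + n_keep + n_discard);

    return n_past - n_discard;
}

int main() {
    const int n_ctx  = 8;
    const int n_keep = 2;

    std::vector<int> cache = {0, 1, 2, 3, 4, 5, 6, 7}; // pretend token ids
    int n_past = (int) cache.size();

    if (n_past >= n_ctx) {                 // context is full
        n_past = context_shift(cache, n_past, n_keep);
    }

    std::printf("n_past after shift: %d\n", n_past); // 5: kept 2 + remaining 3
    return 0;
}

With --no-context-shift the branch is skipped entirely: once n_past + embd.size() reaches n_ctx the example logs that context is full and breaks out of the generation loop, which is the behavior this commit adds.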