adjust fragmentation fix

This commit is contained in:
Concedo 2023-12-02 15:59:08 +08:00
parent 1c422f45cb
commit 12f66eaa1d

View file

@@ -74,6 +74,7 @@ static llama_v3_context * llama_ctx_v3;
 static llama_context * llama_ctx_v4;
 static gpt_params params;
+static int max_context_limit_at_load = 0;
 static int n_past = 0;
 static int n_threads = 4;
 static int n_blasthreads = 4;
@@ -690,6 +691,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 }
 params.n_ctx = clamped_max_context_length;
+max_context_limit_at_load = clamped_max_context_length;
 neox_ctx_v2.hparams.n_ctx = neox_ctx_v3.hparams.n_ctx
 = gptj_ctx_v1.hparams.n_ctx = gptj_ctx_v2.hparams.n_ctx = gptj_ctx_v3.hparams.n_ctx
@@ -1446,6 +1448,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 bool stream_sse = inputs.stream_sse;
 if(params.n_ctx >= 256 && useContextShift && (file_format == FileFormat::GGUF_LLAMA || file_format==FileFormat::GGUF_FALCON))
+{
+if(params.n_ctx + extra_context_handle_fragmentation >= max_context_limit_at_load)
 {
 params.n_ctx -= extra_context_handle_fragmentation; //add some additional buffer to handle KV fragmentation
 if(debugmode==1)
@@ -1453,6 +1457,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 printf("\nTrue max context permitted: %d\n",params.n_ctx);
 }
 }
+}
 bool allow_regular_prints = (debugmode!=-1 && !inputs.quiet) || debugmode >= 1;