From 8ad42a1102ea5d37fb06bce78ad6fe03ea8515ce Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Fri, 14 Apr 2023 21:30:26 +0800
Subject: [PATCH] read from inputs

---
 gpttype_adapter.cpp | 4 +++-
 llama_adapter.cpp   | 4 +++-
 2 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp
index c025db5a5..83974e761 100644
--- a/gpttype_adapter.cpp
+++ b/gpttype_adapter.cpp
@@ -29,6 +29,7 @@ static gpt_params params;
 static int n_past = 0;
 static int n_threads = 4;
 static int n_batch = 8;
+static bool useSmartContext = false;
 static std::string modelname;
 static std::vector last_n_tokens;
 static std::vector current_context_tokens;
@@ -51,6 +52,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
     n_threads = params.n_threads = inputs.threads;
     n_batch = params.n_batch = inputs.batch_size;
     modelname = params.model = inputs.model_filename;
+    useSmartContext = inputs.use_smartcontext;
     params.memory_f16 = inputs.f16_kv;
     params.n_ctx = inputs.max_context_length;
     model_v1.hparams.n_ctx = model_v2.hparams.n_ctx = model_gpt2_v1.hparams.n_ctx = model_gpt2_v2.hparams.n_ctx = params.n_ctx;
@@ -196,7 +198,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     n_past = 0;
 
-    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, true);
+    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext);
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
     // bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);
diff --git a/llama_adapter.cpp b/llama_adapter.cpp
index 20cdd4fcb..ba14fe250 100644
--- a/llama_adapter.cpp
+++ b/llama_adapter.cpp
@@ -27,6 +27,7 @@ static gpt_params params;
 static int n_past = 0;
 static int n_threads = 4;
 static int n_batch = 8;
+static bool useSmartContext = false;
 static std::string modelname;
 static llama_context *ctx;
 static std::vector last_n_tokens;
@@ -42,6 +43,7 @@ bool llama_load_model(const load_model_inputs inputs, FileFormat in_file_format)
     n_threads = inputs.threads;
     n_batch = inputs.batch_size;
     modelname = inputs.model_filename;
+    useSmartContext = inputs.use_smartcontext;
 
     ctx_params.n_ctx = inputs.max_context_length;
     ctx_params.n_parts = -1;//inputs.n_parts_overwrite;
@@ -133,7 +135,7 @@ generation_outputs llama_generate(const generation_inputs inputs, generation_out
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
     n_past = 0;
 
-    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, true);
+    ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext);
 
     //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
     bool blasmode = (embd_inp.size() >= 32 && ggml_cpu_has_blas());
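
Note (not part of the patch): a minimal illustrative sketch of the wiring this change introduces. It assumes a load_model_inputs struct with a use_smartcontext field and a ContextFastForward helper whose last parameter toggles smart-context fast-forwarding; both are reconstructed from the call sites above and are placeholders, not the repository's real definitions. The point of the change is that the flag is captured once at model-load time and then forwarded at generation time in place of the previous hard-coded true.

#include <string>
#include <vector>

struct load_model_inputs {            // assumed subset of the real struct
    int threads;
    int batch_size;
    std::string model_filename;
    bool use_smartcontext;            // field the patch reads
};

static bool useSmartContext = false;  // module-level copy, as in the patch

// Hypothetical signature modelled on the call sites in the diff above.
static void ContextFastForward(std::vector<int> &ctx_tokens, std::vector<int> &embd_inp,
                               int &n_past, std::vector<int> &last_n_tokens,
                               int nctx, std::vector<int> &smartcontext, bool use_smart)
{
    // The real helper fast-forwards past reusable context; body omitted here.
    (void)ctx_tokens; (void)embd_inp; (void)n_past;
    (void)last_n_tokens; (void)nctx; (void)smartcontext; (void)use_smart;
}

// Load time: remember the caller's preference once.
void load_model(const load_model_inputs &inputs)
{
    useSmartContext = inputs.use_smartcontext;
}

// Generation time: pass the stored flag instead of a hard-coded true.
void generate(std::vector<int> &ctx_tokens, std::vector<int> &embd_inp, int &n_past,
              std::vector<int> &last_n_tokens, int nctx, std::vector<int> &smartcontext)
{
    ContextFastForward(ctx_tokens, embd_inp, n_past, last_n_tokens, nctx,
                       smartcontext, useSmartContext);
}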