read from inputs

Concedo 2023-04-14 21:30:26 +08:00
parent adb4df78d6
commit 8ad42a1102
2 changed files with 6 additions and 2 deletions


@@ -29,6 +29,7 @@ static gpt_params params;
 static int n_past = 0;
 static int n_threads = 4;
 static int n_batch = 8;
+static bool useSmartContext = false;
 static std::string modelname;
 static std::vector<gpt_vocab::id> last_n_tokens;
 static std::vector<gpt_vocab::id> current_context_tokens;
@@ -51,6 +52,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in
 n_threads = params.n_threads = inputs.threads;
 n_batch = params.n_batch = inputs.batch_size;
 modelname = params.model = inputs.model_filename;
+useSmartContext = inputs.use_smartcontext;
 params.memory_f16 = inputs.f16_kv;
 params.n_ctx = inputs.max_context_length;
 model_v1.hparams.n_ctx = model_v2.hparams.n_ctx = model_gpt2_v1.hparams.n_ctx = model_gpt2_v2.hparams.n_ctx = params.n_ctx;
@@ -196,7 +198,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o
 std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 n_past = 0;
-ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, true);
+ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext);
 //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
 // bool approved_format = (file_format!=FileFormat::GPT2_1 && file_format!=FileFormat::GPTJ_1 && file_format!=FileFormat::GPTJ_2);


@@ -27,6 +27,7 @@ static gpt_params params;
 static int n_past = 0;
 static int n_threads = 4;
 static int n_batch = 8;
+static bool useSmartContext = false;
 static std::string modelname;
 static llama_context *ctx;
 static std::vector<llama_token> last_n_tokens;
@@ -42,6 +43,7 @@ bool llama_load_model(const load_model_inputs inputs, FileFormat in_file_format)
 n_threads = inputs.threads;
 n_batch = inputs.batch_size;
 modelname = inputs.model_filename;
+useSmartContext = inputs.use_smartcontext;
 ctx_params.n_ctx = inputs.max_context_length;
 ctx_params.n_parts = -1;//inputs.n_parts_overwrite;
@@ -133,7 +135,7 @@ generation_outputs llama_generate(const generation_inputs inputs, generation_out
 std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 n_past = 0;
-ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, true);
+ContextFastForward(current_context_tokens, embd_inp, n_past, last_n_tokens, nctx, smartcontext, useSmartContext);
 //if using BLAS and prompt is big enough, switch to single thread and use a huge batch
 bool blasmode = (embd_inp.size() >= 32 && ggml_cpu_has_blas());