From 491d2da02f9f1ce06b5616375f421fd0a8ccbb3c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 11 Mar 2024 14:45:55 +0200 Subject: [PATCH] llama : n_parallel -> n_seq_max --- llama.cpp | 6 +++--- llama.h | 2 +- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index 44c73df5c..33223e096 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12551,7 +12551,7 @@ struct llama_context_params llama_context_default_params() { /*.seed =*/ LLAMA_DEFAULT_SEED, /*.n_ctx =*/ 512, /*.n_batch =*/ 512, - /*.n_parallel =*/ 1, + /*.n_seq_max =*/ 1, /*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, @@ -12713,7 +12713,7 @@ struct llama_context * llama_new_context_with_model( auto & cparams = ctx->cparams; cparams.n_batch = params.n_batch; - // TODO: maybe add n_parallel here too + // TODO: maybe add n_seq_max here too cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -12780,7 +12780,7 @@ struct llama_context * llama_new_context_with_model( // Mamba only needs a constant number of KV cache cells per sequence if (model->arch == LLM_ARCH_MAMBA) { // Mamba needs at least as many KV cells as there are sequences kept at any time - kv_size = std::max((uint32_t) 1, params.n_parallel); + kv_size = std::max((uint32_t) 1, params.n_seq_max); // it's probably best to keep as much precision as possible for the states type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states diff --git a/llama.h b/llama.h index 55cc92fe7..446899da6 100644 --- a/llama.h +++ b/llama.h @@ -235,7 +235,7 @@ extern "C" { uint32_t seed; // RNG seed, -1 for random uint32_t n_ctx; // text context, 0 = from model uint32_t n_batch; // prompt processing maximum batch size - uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models) + uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models) uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing