llama : n_parallel -> n_seq_max

This commit is contained in:
Georgi Gerganov 2024-03-11 14:45:55 +02:00
parent 32daccd755
commit 491d2da02f
No known key found for this signature in database
GPG key ID: BF970631944C16B7
2 changed files with 4 additions and 4 deletions

View file

@ -12551,7 +12551,7 @@ struct llama_context_params llama_context_default_params() {
/*.seed =*/ LLAMA_DEFAULT_SEED,
/*.n_ctx =*/ 512,
/*.n_batch =*/ 512,
/*.n_parallel =*/ 1,
/*.n_seq_max =*/ 1,
/*.n_threads =*/ GGML_DEFAULT_N_THREADS, // TODO: better default
/*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS,
/*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
@ -12713,7 +12713,7 @@ struct llama_context * llama_new_context_with_model(
auto & cparams = ctx->cparams;
cparams.n_batch = params.n_batch;
// TODO: maybe add n_parallel here too
// TODO: maybe add n_seq_max here too
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@ -12780,7 +12780,7 @@ struct llama_context * llama_new_context_with_model(
// Mamba only needs a constant number of KV cache cells per sequence
if (model->arch == LLM_ARCH_MAMBA) {
// Mamba needs at least as many KV cells as there are sequences kept at any time
kv_size = std::max((uint32_t) 1, params.n_parallel);
kv_size = std::max((uint32_t) 1, params.n_seq_max);
// it's probably best to keep as much precision as possible for the states
type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states
type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states

View file

@ -235,7 +235,7 @@ extern "C" {
uint32_t seed; // RNG seed, -1 for random
uint32_t n_ctx; // text context, 0 = from model
uint32_t n_batch; // prompt processing maximum batch size
uint32_t n_parallel; // number of parallel sequences (i.e. distinct states for recurrent models)
uint32_t n_seq_max; // max number of sequences (i.e. distinct states for recurrent models)
uint32_t n_threads; // number of threads to use for generation
uint32_t n_threads_batch; // number of threads to use for batch processing