From 4400153348693db3ef008dbf73c78a2a3ae6e39d Mon Sep 17 00:00:00 2001 From: slaren Date: Wed, 13 Mar 2024 02:55:29 +0100 Subject: [PATCH] llama : limit n_batch and n_ubatch to n_ctx during context creation --- llama.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/llama.cpp b/llama.cpp index a6388489d..b3161b521 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12757,8 +12757,6 @@ struct llama_context * llama_new_context_with_model( auto & cparams = ctx->cparams; // TODO: maybe add n_seq_max here too - cparams.n_batch = params.n_batch; - cparams.n_ubatch = params.n_ubatch == 0 ? params.n_batch : params.n_ubatch; cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -12774,6 +12772,10 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; + cparams.n_batch = std::min(cparams.n_ctx, params.n_batch); + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + + cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : hparams.n_ctx_train;