diff --git a/llama.cpp b/llama.cpp index a6388489d..b3161b521 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12757,8 +12757,6 @@ struct llama_context * llama_new_context_with_model( auto & cparams = ctx->cparams; // TODO: maybe add n_seq_max here too - cparams.n_batch = params.n_batch; - cparams.n_ubatch = params.n_ubatch == 0 ? params.n_batch : params.n_ubatch; cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; cparams.yarn_ext_factor = params.yarn_ext_factor; @@ -12774,6 +12772,10 @@ struct llama_context * llama_new_context_with_model( cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base; cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale; + cparams.n_batch = std::min(cparams.n_ctx, params.n_batch); + cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch); + + cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx : hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx : hparams.n_ctx_train;