llama : limit n_batch and n_ubatch to n_ctx during context creation

This commit is contained in:
slaren 2024-03-13 02:55:29 +01:00
parent 255c1ec18e
commit 4400153348

View file

@ -12757,8 +12757,6 @@ struct llama_context * llama_new_context_with_model(
auto & cparams = ctx->cparams;
// TODO: maybe add n_seq_max here too
cparams.n_batch = params.n_batch;
cparams.n_ubatch = params.n_ubatch == 0 ? params.n_batch : params.n_ubatch;
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
cparams.yarn_ext_factor = params.yarn_ext_factor;
@ -12774,6 +12772,10 @@ struct llama_context * llama_new_context_with_model(
cparams.rope_freq_base = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
cparams.rope_freq_scale = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
cparams.n_batch = std::min(cparams.n_ctx, params.n_batch);
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
hparams.n_ctx_train;