llama : do not limit n_batch to n_ctx with non-causal attn

slaren 2024-03-13 14:59:32 +01:00
parent cda49d3828
commit 015e1bfe64


@@ -12773,7 +12773,8 @@ struct llama_context * llama_new_context_with_model(
     cparams.rope_freq_base   = params.rope_freq_base  == 0.0f ? hparams.rope_freq_base_train  : params.rope_freq_base;
     cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
-    cparams.n_batch          = std::min(cparams.n_ctx, params.n_batch);
+    // with causal attention, the batch size is limited by the context size
+    cparams.n_batch          = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
     cparams.n_ubatch         = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
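
In effect, contexts created for non-causal models (for example BERT-style embedding models, where hparams.causal_attn is false) may now use an n_batch larger than n_ctx, while causal models keep the old clamp. Below is a minimal caller sketch, not part of the commit: the model file name and the sizes are hypothetical, and only public llama.cpp API functions from around the time of this commit are used.

#include "llama.h"

int main(void) {
    llama_backend_init();

    // hypothetical non-causal (embedding) model file
    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("bert-embed.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx   = 512;  // per-sequence context size
    cparams.n_batch = 2048; // with non-causal attn this is no longer clamped to n_ctx;
                            // a causal model would still end up with n_batch = 512 here

    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        llama_free_model(model);
        return 1;
    }

    // ... build batches of up to 2048 tokens and call llama_decode ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}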