llama : fix n_batch requirements

ggml-ci
Georgi Gerganov 2024-04-23 17:30:37 +03:00
parent 19e8982f51
commit 56657e52e5


@@ -15064,10 +15064,6 @@ struct llama_context * llama_new_context_with_model(
const auto & hparams = model->hparams;
auto & cparams = ctx->cparams;
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch);
cparams.n_seq_max = std::max(1u, params.n_seq_max);
cparams.n_threads = params.n_threads;
cparams.n_threads_batch = params.n_threads_batch;
@@ -15090,8 +15086,16 @@ struct llama_context * llama_new_context_with_model(
// with causal attention, the batch size is limited by the context size
cparams.n_batch = hparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
// the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
// this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
// ref: https://github.com/ggerganov/llama.cpp/pull/5021
if (cparams.n_batch < GGML_KQ_MASK_PAD) {
LLAMA_LOG_WARN("%s: n_batch is less than GGML_KQ_MASK_PAD - increasing to %d\n", __func__, GGML_KQ_MASK_PAD);
cparams.n_batch = GGML_KQ_MASK_PAD;
}
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
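
Below is a minimal caller-side sketch of the behavior after this change; it is not part of the commit. It assumes `model` is an already loaded `llama_model *` and uses the public llama.cpp C API (`llama_context_default_params`, `llama_new_context_with_model`, `llama_n_batch`); the concrete value of GGML_KQ_MASK_PAD comes from ggml.h.

    // request a batch size smaller than the KQ mask padding
    struct llama_context_params ctx_params = llama_context_default_params();
    ctx_params.n_batch = 4;   // below GGML_KQ_MASK_PAD

    // with this commit, context creation logs the new warning and raises n_batch,
    // instead of unconditionally applying std::max() up front:
    //   llama_new_context_with_model: n_batch is less than GGML_KQ_MASK_PAD - increasing to ...
    struct llama_context * ctx = llama_new_context_with_model(model, ctx_params);

    // llama_n_batch(ctx) now reports GGML_KQ_MASK_PAD rather than 4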