diff --git a/llama.cpp b/llama.cpp
index 26802d96a..83e0c2ef1 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1823,7 +1823,7 @@ struct llama_hparams {
     float f_logit_scale = 0.0f;
 
     bool causal_attn = true;
-    bool need_kq_pos = false; // currently, we need KQ_pos data for ALiBi-based models
+    bool use_alibi   = false; // currently, we need KQ_pos data for ALiBi-based models
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
@@ -4104,7 +4104,7 @@ static void llm_load_hparams(
     model.ftype = ml.ftype;
 
     if (hparams.f_max_alibi_bias > 0.0f) {
-        hparams.need_kq_pos = true;
+        hparams.use_alibi = true;
     }
 
     hparams.rope_type = llama_rope_type(&model);
@@ -6269,7 +6269,6 @@ static struct ggml_tensor * llm_build_moe_ffn(
     return moe_out;
 }
 
-// if max_alibi_bias > 0 then apply ALiBi
 static struct ggml_tensor * llm_build_kqv(
         struct ggml_context * ctx,
           const llama_model & model,
@@ -6359,7 +6358,7 @@ static struct ggml_tensor * llm_build_kqv(
 #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute")
 #pragma message("      Falling back to ggml_alibi(). Will become an error in Mar 2024")
 #pragma message("ref: https://github.com/ggerganov/llama.cpp/pull/5488")
-        if (hparams.f_max_alibi_bias > 0.0f) {
+        if (hparams.use_alibi) {
             kq = ggml_scale(ctx, kq, kq_scale);
             cb(kq, "kq_scaled", il);
 
@@ -10714,7 +10713,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    if (hparams.need_kq_pos) {
+    // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch
+    // this allows to process multiple sequences in parallel with ALiBi-based models
+    if (hparams.use_alibi) {
        const int64_t n_kv = kv_self.n;
 
        GGML_ASSERT(lctx.inp_KQ_pos);
@@ -15116,7 +15117,7 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
-    if (cparams.flash_attn && hparams.need_kq_pos) {
+    if (cparams.flash_attn && hparams.use_alibi) {
         LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__);
         cparams.flash_attn = false;
     }
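
For context only (not part of the patch): the sketch below illustrates what the per-token positions carried by KQ_pos are used for when `use_alibi` is set. It is a minimal standalone approximation, assuming the standard ALiBi slope formula for a power-of-two head count; the constants (`n_head`, `max_bias`, the example positions) are made up for illustration and do not come from the diff.

// Illustrative sketch: per-token key positions (what lctx.inp_KQ_pos provides)
// turned into the per-head ALiBi bias added to the KQ attention scores.
// Assumes n_head is a power of two; max_bias plays the role of hparams.f_max_alibi_bias.
#include <cmath>
#include <cstdio>
#include <vector>

// head h (0-based) of n_head gets slope 2^(-max_bias * (h+1) / n_head)
static float alibi_slope(int h, int n_head, float max_bias) {
    return std::pow(2.0f, -max_bias * float(h + 1) / float(n_head));
}

int main() {
    const int   n_head   = 8;    // example value
    const float max_bias = 8.0f; // example value
    const std::vector<int> kq_pos = {0, 1, 2, 3}; // example positions of the key tokens

    // bias added to the score of (query, key j) for head h is slope_h * pos[j];
    // since softmax is shift-invariant per row, this matches the usual -slope_h * (i - j) form
    for (int h = 0; h < n_head; ++h) {
        const float m = alibi_slope(h, n_head, max_bias);
        std::printf("head %d slope %.6f biases:", h, m);
        for (int p : kq_pos) {
            std::printf(" %.4f", m * p);
        }
        std::printf("\n");
    }
    return 0;
}

Because the bias depends only on the key positions, supplying the true per-sequence positions in KQ_pos is what lets a single batch mix tokens from multiple independent sequences, as the added comment in llama_set_inputs notes.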