llama : replace bool need_kq_pos with use_alibi

Author: Georgi Gerganov
Date:   2024-04-23 17:15:13 +03:00
Commit: 78d363b0d4
Parent: 3864eea4cb
Signed with GPG key ID 449E073F9DC10735 (no known key found for this signature in database)


@@ -1823,7 +1823,7 @@ struct llama_hparams {
     float f_logit_scale = 0.0f;
 
     bool causal_attn = true;
-    bool need_kq_pos = false; // currently, we need KQ_pos data for ALiBi-based models
+    bool use_alibi   = false; // currently, we need KQ_pos data for ALiBi-based models
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
@@ -4104,7 +4104,7 @@ static void llm_load_hparams(
     model.ftype = ml.ftype;
 
     if (hparams.f_max_alibi_bias > 0.0f) {
-        hparams.need_kq_pos = true;
+        hparams.use_alibi = true;
     }
 
     hparams.rope_type = llama_rope_type(&model);
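
For context: `f_max_alibi_bias` is the model's maximum ALiBi bias, from which ggml derives a per-head slope schedule inside `ggml_soft_max_ext`. A minimal standalone sketch of that schedule (illustrative names, not the actual ggml code):

```cpp
// Sketch: per-head ALiBi slopes derived from max_alibi_bias, following the
// power-of-two schedule ggml uses. Standalone approximation, not ggml code.
#include <cmath>
#include <cstdio>
#include <vector>

static std::vector<float> alibi_slopes(int n_head, float max_bias) {
    // round n_head down to a power of two for the base schedule
    const int n_head_log2 = 1 << (int) std::floor(std::log2((float) n_head));

    const float m0 = std::pow(2.0f, -max_bias          / n_head_log2);
    const float m1 = std::pow(2.0f, -(max_bias / 2.0f) / n_head_log2);

    std::vector<float> slopes(n_head);
    for (int h = 0; h < n_head; ++h) {
        // heads beyond the power-of-two count interpolate with the m1 base
        slopes[h] = h < n_head_log2
            ? std::pow(m0, (float)(h + 1))
            : std::pow(m1, (float)(2*(h - n_head_log2) + 1));
    }
    return slopes;
}

int main() {
    for (float s : alibi_slopes(8, 8.0f)) {
        std::printf("%f\n", s);
    }
}
```

For n_head = 8 and max_bias = 8.0 this reproduces the canonical ALiBi slopes 1/2, 1/4, ..., 1/256.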
@@ -6269,7 +6269,6 @@ static struct ggml_tensor * llm_build_moe_ffn(
     return moe_out;
 }
 
-// if max_alibi_bias > 0 then apply ALiBi
 static struct ggml_tensor * llm_build_kqv(
          struct ggml_context * ctx,
        const llama_model & model,
@@ -6359,7 +6358,7 @@ static struct ggml_tensor * llm_build_kqv(
 #pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute")
 #pragma message("      Falling back to ggml_alibi(). Will become an error in Mar 2024")
 #pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5488")
-        if (hparams.f_max_alibi_bias > 0.0f) {
+        if (hparams.use_alibi) {
             kq = ggml_scale(ctx, kq, kq_scale);
             cb(kq, "kq_scaled", il);
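
On backends without fused ALiBi support in `ggml_soft_max_ext`, this gated fallback applies the bias explicitly between scaling and softmax. What that bias amounts to numerically, as a plain-C++ sketch (hypothetical helper, not the ggml graph code):

```cpp
// Sketch of the ALiBi bias the fallback path applies before softmax:
// for head h, the score of query i against key j gets slope[h] * pos[j]
// added after the 1/sqrt(d) scaling. Plain C++ illustration of the math.
#include <vector>

void add_alibi_bias(std::vector<float> & kq_row,       // scores of one query against n_kv keys
                    const std::vector<float> & kq_pos, // KQ_pos: position of each cached token
                    float slope) {                     // per-head ALiBi slope
    for (size_t j = 0; j < kq_row.size(); ++j) {
        kq_row[j] += slope * kq_pos[j];
    }
}
```

Using the actual per-token positions from KQ_pos, rather than the raw column index, is what lets the bias stay correct when several sequences share one batch.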
@@ -10714,7 +10713,9 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
             }
         }
 
-    if (hparams.need_kq_pos) {
+    // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch
+    // this allows to process multiple sequences in parallel with ALiBi-based models
+    if (hparams.use_alibi) {
         const int64_t n_kv = kv_self.n;
 
         GGML_ASSERT(lctx.inp_KQ_pos);
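
`inp_KQ_pos` is the per-KV-cell position buffer that the renamed flag gates. A minimal sketch of what filling it amounts to, with hypothetical types standing in for llama.cpp's KV-cache internals:

```cpp
// Sketch: populate a KQ_pos-style buffer with one float per KV cell, holding
// that token's sequence position. kv_cell here is a hypothetical stand-in.
struct kv_cell { int pos; };

void fill_kq_pos(float * data, const kv_cell * cells, int n_kv) {
    for (int i = 0; i < n_kv; ++i) {
        data[i] = (float) cells[i].pos; // per-token position, so sequences can be batched
    }
}
```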
@@ -15116,7 +15117,7 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
-    if (cparams.flash_attn && hparams.need_kq_pos) {
+    if (cparams.flash_attn && hparams.use_alibi) {
         LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__);
         cparams.flash_attn = false;
     }
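
From the API side, callers opt in to flash attention via the context params, and ALiBi models get the flag forced back off with the warning above. A hedged usage sketch, assuming the `flash_attn` field that this flash-attention branch adds to `llama_context_params`:

```cpp
// Sketch: request flash attention at context creation. Assumes the flash_attn
// field on llama_context_params from this branch; for ALiBi models the
// library silently resets it to false and logs the warning shown above.
#include "llama.h"

llama_context * make_ctx(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = true;
    return llama_new_context_with_model(model, cparams);
}
```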