llama : force disable flash attention for incompatible models

This commit is contained in:
Georgi Gerganov 2024-04-22 12:50:41 +03:00
parent cb76d747d1
commit c11d05fec0
No known key found for this signature in database
GPG key ID: BF970631944C16B7

View file

@ -1823,7 +1823,7 @@ struct llama_hparams {
float f_logit_scale = 0.0f;
bool causal_attn = true;
bool need_kq_pos = false;
bool need_kq_pos = false; // currently, we need KQ_pos data for ALiBi-based models
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
enum llama_rope_type rope_type = LLAMA_ROPE_TYPE_NONE;
@ -6311,6 +6311,8 @@ static struct ggml_tensor * llm_build_kqv(
GGML_UNUSED(model);
GGML_UNUSED(n_ctx);
// note: if this assert triggers, then some check has failed earlier
// the idea is to detect during context creation that ALiBi would be used and disable Flash Attention
GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
// split cached v into n_head heads (not transposed)
@ -15114,6 +15116,16 @@ struct llama_context * llama_new_context_with_model(
}
}
if (cparams.flash_attn && hparams.need_kq_pos) {
LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__);
cparams.flash_attn = false;
}
if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
cparams.flash_attn = false;
}
if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL);
}