llama : force disable flash attention for incompatible models
commit c11d05fec0
parent cb76d747d1
1 changed file with 13 additions and 1 deletion
llama.cpp | 14 +++++++++++++-
@@ -1823,7 +1823,7 @@ struct llama_hparams {
     float f_logit_scale = 0.0f;
 
     bool causal_attn = true;
-    bool need_kq_pos = false;
+    bool need_kq_pos = false; // currently, we need KQ_pos data for ALiBi-based models
 
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type    rope_type    = LLAMA_ROPE_TYPE_NONE;
@@ -6311,6 +6311,8 @@ static struct ggml_tensor * llm_build_kqv(
         GGML_UNUSED(model);
         GGML_UNUSED(n_ctx);
 
+        // note: if this assert triggers, then some check has failed earlier
+        // the idea is to detect during context creation that ALiBi would be used and disable Flash Attention
         GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
 
         // split cached v into n_head heads (not transposed)
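The comments added above spell out the intent of the assert: llm_build_kqv() should never see a KQ_pos tensor while flash attention is enabled, because the context-creation check introduced in the next hunk turns flash attention off for ALiBi-based models before any graph is built. A self-contained sketch of this "disable at setup, assert at use" pattern, using hypothetical names (attn_config, resolve_config, build_attention) rather than the actual llama.cpp internals:

#include <cassert>
#include <cstdio>

// Hypothetical illustration of the pattern in this commit: incompatible option
// combinations are resolved once at setup time, and the hot path only asserts
// that the invariant holds. Names here are illustrative, not llama.cpp code.
struct attn_config {
    bool use_alibi  = false;
    bool flash_attn = false;
};

// setup-time check: plays the role of the new logic in llama_new_context_with_model()
static void resolve_config(attn_config & cfg) {
    if (cfg.flash_attn && cfg.use_alibi) {
        fprintf(stderr, "warn: flash_attn is not compatible with ALiBi - forcing off\n");
        cfg.flash_attn = false;
    }
}

// build-time invariant: plays the role of the GGML_ASSERT in llm_build_kqv()
static void build_attention(const attn_config & cfg) {
    if (cfg.flash_attn) {
        // if this fires, resolve_config() failed to catch the combination earlier
        assert(!cfg.use_alibi && "ALiBi is not supported with flash attention");
        // ... build the flash-attention graph ...
    } else {
        // ... build the regular attention graph (with ALiBi bias if requested) ...
    }
}

int main() {
    attn_config cfg;
    cfg.use_alibi  = true;
    cfg.flash_attn = true;

    resolve_config(cfg);  // logs the warning and flips flash_attn off
    build_attention(cfg); // the assert cannot fire once the config is resolved
    return 0;
}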
@@ -15114,6 +15116,16 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
+    if (cparams.flash_attn && hparams.need_kq_pos) {
+        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__);
+        cparams.flash_attn = false;
+    }
+
+    if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
+        LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
+        cparams.flash_attn = false;
+    }
+
     if (params.seed == LLAMA_DEFAULT_SEED) {
         params.seed = time(NULL);
     }
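Seen from the public API, the effect of the two new checks is that requesting flash attention for an incompatible model no longer reaches the unsupported graph path: llama_new_context_with_model() logs the warning and continues with flash attention forced off. A minimal usage sketch, assuming the llama.h of this period (a bool flash_attn field in llama_context_params, llama_load_model_from_file / llama_new_context_with_model); the GGUF path is a placeholder for an ALiBi-based (e.g. MPT/BLOOM) or Grok model:

#include <cstdio>
#include "llama.h"

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    // placeholder path: e.g. an ALiBi-based (MPT/BLOOM) or Grok GGUF
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = true; // request flash attention

    // for an incompatible model this now logs
    //   "flash_attn is not yet compatible with ALiBi - forcing off" (or the Grok variant)
    // and the context is created with flash attention disabled instead of failing later
    llama_context * ctx = llama_new_context_with_model(model, cparams);
    if (ctx == NULL) {
        fprintf(stderr, "failed to create context\n");
        llama_free_model(model);
        return 1;
    }

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}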