CPU/CUDA: Gemma 2 FlashAttention support (#8542)
* CPU/CUDA: Gemma 2 FlashAttention support
* apply logit_softcap to scale in kernel
* disable logit softcapping tests on Metal
* remove metal check
parent: 8f824ffe8e
commit: e11bd856d5

12 changed files with 319 additions and 79 deletions
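Background on the commit-message bullet "apply logit_softcap to scale in kernel": Gemma 2 squashes attention logits through tanh so their magnitude never exceeds a fixed cap. The sketch below is a minimal standalone C++ illustration of that math and of the folding trick the bullet refers to; the function names are illustrative and do not appear in ggml.

```cpp
#include <cmath>

// Gemma 2 logit softcapping: cap * tanh(x / cap), with x = scale * dot(q, k).
// A cap of 0.0f conventionally means "disabled" (see the 0.0f sentinel
// passed to ggml_flash_attn_ext in the first hunk below).
static float softcap_naive(float qk_dot, float scale, float cap) {
    if (cap <= 0.0f) {
        return scale * qk_dot;
    }
    return cap * tanhf(scale * qk_dot / cap);
}

// "apply logit_softcap to scale in kernel": since the tanh argument is
// (scale / cap) * qk_dot, the division by the cap can be folded into the
// scale once per kernel launch instead of once per score.
static float softcap_folded(float qk_dot, float scale_over_cap, float cap) {
    return cap * tanhf(scale_over_cap * qk_dot);
}
```

For any `qk_dot`, `softcap_folded(qk_dot, scale / cap, cap)` equals `softcap_naive(qk_dot, scale, cap)`, which is why the kernel can premultiply the scale and skip a per-element division.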
@@ -8874,7 +8874,8 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
         cb(v, "v", il);

-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias,
+                                  hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);

         if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
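The hunk above threads a new trailing argument through `ggml_flash_attn_ext`: the softcap value, or `0.0f` when the model does not use softcapping. As a rough mental model of where that cap sits relative to the scale, the mask, and the softmax, here is a scalar reference for one attention row; this is an illustrative sketch with hypothetical names, not ggml's fused kernel.

```cpp
#include <cmath>
#include <vector>

// One row of masked, softcapped attention probabilities. q has length d,
// k holds n_kv rows of length d, mask is an additive bias (may be null).
static std::vector<float> attn_row(const float * q, const float * k,
                                   const float * mask, int n_kv, int d,
                                   float scale, float softcap) {
    std::vector<float> p(n_kv);
    float max = -INFINITY;
    for (int j = 0; j < n_kv; ++j) {
        float s = 0.0f;
        for (int t = 0; t < d; ++t) {
            s += q[t] * k[j*d + t];
        }
        s *= scale;                           // kq_scale, typically 1/sqrt(d)
        if (softcap > 0.0f) {
            s = softcap * tanhf(s / softcap); // the new argument, 0.0f = off
        }
        if (mask) {
            s += mask[j];                     // kq_mask, -INFINITY for masked slots
        }
        p[j] = s;
        if (s > max) {
            max = s;
        }
    }
    float sum = 0.0f;
    for (int j = 0; j < n_kv; ++j) {          // numerically stable softmax
        p[j] = expf(p[j] - max);
        sum += p[j];
    }
    for (int j = 0; j < n_kv; ++j) {
        p[j] /= sum;
    }
    return p;                                 // caller multiplies by V
}
```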
@@ -17533,12 +17534,6 @@ struct llama_context * llama_new_context_with_model(
         params.flash_attn = false;
     }

-    if (params.flash_attn && model->hparams.attn_soft_cap) {
-        LLAMA_LOG_WARN("%s: flash_attn is not compatible with attn_soft_cap - forcing off\n", __func__);
-        params.flash_attn = false;
-    }
-
-
     if (params.flash_attn && model->hparams.n_embd_head_k != model->hparams.n_embd_head_v) {
         LLAMA_LOG_WARN("%s: flash_attn requires n_embd_head_k == n_embd_head_v - forcing off\n", __func__);
         params.flash_attn = false;
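The second hunk deletes the guard that silently forced `flash_attn` off whenever `hparams.attn_soft_cap` was set, so FlashAttention can now stay enabled for softcap models such as Gemma 2 (the `n_embd_head_k == n_embd_head_v` check remains). A hedged caller-side sketch, assuming the llama.cpp C API as of this commit; the model path is a placeholder:

```cpp
#include "llama.h"

int main(void) {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_load_model_from_file("gemma-2-9b-it-Q4_K_M.gguf", mparams);
    if (model == NULL) {
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn = true; // before this commit, reset to false for softcap models
    llama_context * ctx = llama_new_context_with_model(model, cparams);

    // ... evaluate tokens as usual ...

    llama_free(ctx);
    llama_free_model(model);
    llama_backend_free();
    return 0;
}
```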