llama : gptneox arch use F32 attn prec

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-05-23 15:01:40 +03:00
parent 907c135a94
commit d2bae45546
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -6718,7 +6718,7 @@ static struct ggml_tensor * llm_build_kqv(
cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias); cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32); ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
} }
@ -6727,7 +6727,7 @@ static struct ggml_tensor * llm_build_kqv(
struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
cb(kq, "kq", il); cb(kq, "kq", il);
if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) { if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3 || model.arch == LLM_ARCH_GPTNEOX) {
// for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs // for this arch, we need to perform the KQ multiplication with F32 precision, otherwise we get NaNs
// ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847 // ref: https://github.com/ggerganov/llama.cpp/pull/4490#issuecomment-1859055847
ggml_mul_mat_set_prec(kq, GGML_PREC_F32); ggml_mul_mat_set_prec(kq, GGML_PREC_F32);