From dbbc2633137f15205a80466451a0ebe5ba8baf2f Mon Sep 17 00:00:00 2001 From: xaedes Date: Mon, 3 Jul 2023 18:45:18 +0200 Subject: [PATCH] add conditional compilation for using F16 exp in flash attention Uncomment `// #define GGML_FLASH_ATTN_EXP_FP16` to enable the F16 exp lookup table in flash attention; by default `expf` is used. --- ggml.c | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 2138cb8bc..53f2c4254 100644 --- a/ggml.c +++ b/ggml.c @@ -124,6 +124,7 @@ typedef void * thread_ret_t; #define GGML_GELU_QUICK_FP16 #define GGML_SILU_FP16 // #define GGML_CROSS_ENTROPY_EXP_FP16 +// #define GGML_FLASH_ATTN_EXP_FP16 #define GGML_SOFT_MAX_UNROLL 4 #define GGML_VEC_DOT_UNROLL 2 @@ -13111,10 +13112,13 @@ static void ggml_compute_forward_flash_attn_f32( if (SS[j] == -INFINITY) { SS[j] = 0.0f; } else { - // const float val = expf(SS[j] - max); +#ifndef GGML_FLASH_ATTN_EXP_FP16 + const float val = expf(SS[j] - max); +#else ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); +#endif sump[j] += (ggml_float)val; SS[j] = val; } @@ -13703,10 +13707,13 @@ static void ggml_compute_forward_flash_attn_back_f32( if (SR[j] == -INFINITY) { SW[j] = 0.0f; } else { - // const float val = expf(SR[j] - max); +#ifndef GGML_FLASH_ATTN_EXP_FP16 + const float val = expf(SR[j] - max); +#else ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max); memcpy(&scvt[j], &s, sizeof(uint16_t)); const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]); +#endif sump[j] += (ggml_float)val; SW[j] = val; }