Add conditional compilation for using F16 exp in flash attention

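Uncomment `// #define GGML_FLASH_ATTN_EXP_FP16` in ggml.c to make flash attention (forward and backward) use the F16 exp lookup table (`table_exp_f16`); while the define stays commented out, the F32 `expf` path is used.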
This commit is contained in:
xaedes 2023-07-03 18:45:18 +02:00
parent 1065c3b7b9
commit dbbc263313
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

11
ggml.c
View file

@ -124,6 +124,7 @@ typedef void * thread_ret_t;
#define GGML_GELU_QUICK_FP16
#define GGML_SILU_FP16
// #define GGML_CROSS_ENTROPY_EXP_FP16
// #define GGML_FLASH_ATTN_EXP_FP16
#define GGML_SOFT_MAX_UNROLL 4
#define GGML_VEC_DOT_UNROLL 2
@ -13111,10 +13112,13 @@ static void ggml_compute_forward_flash_attn_f32(
if (SS[j] == -INFINITY) {
SS[j] = 0.0f;
} else {
// const float val = expf(SS[j] - max);
#ifndef GGML_FLASH_ATTN_EXP_FP16
const float val = expf(SS[j] - max);
#else
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
memcpy(&scvt[j], &s, sizeof(uint16_t));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
#endif
sump[j] += (ggml_float)val;
SS[j] = val;
}
@ -13703,10 +13707,13 @@ static void ggml_compute_forward_flash_attn_back_f32(
if (SR[j] == -INFINITY) {
SW[j] = 0.0f;
} else {
// const float val = expf(SR[j] - max);
#ifndef GGML_FLASH_ATTN_EXP_FP16
const float val = expf(SR[j] - max);
#else
ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
memcpy(&scvt[j], &s, sizeof(uint16_t));
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
#endif
sump[j] += (ggml_float)val;
SW[j] = val;
}