add conditional compilation of using F16 exp in flash attention
uncomment `// #define GGML_FLASH_ATTN_EXP_FP16` to enable usage of f16 exp in flash attention
This commit is contained in:
parent
1065c3b7b9
commit
dbbc263313
1 changed files with 9 additions and 2 deletions
11
ggml.c
11
ggml.c
|
@ -124,6 +124,7 @@ typedef void * thread_ret_t;
|
|||
#define GGML_GELU_QUICK_FP16
|
||||
#define GGML_SILU_FP16
|
||||
// #define GGML_CROSS_ENTROPY_EXP_FP16
|
||||
// #define GGML_FLASH_ATTN_EXP_FP16
|
||||
|
||||
#define GGML_SOFT_MAX_UNROLL 4
|
||||
#define GGML_VEC_DOT_UNROLL 2
|
||||
|
@ -13111,10 +13112,13 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||
if (SS[j] == -INFINITY) {
|
||||
SS[j] = 0.0f;
|
||||
} else {
|
||||
// const float val = expf(SS[j] - max);
|
||||
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
||||
const float val = expf(SS[j] - max);
|
||||
#else
|
||||
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
|
||||
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
||||
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
|
||||
#endif
|
||||
sump[j] += (ggml_float)val;
|
||||
SS[j] = val;
|
||||
}
|
||||
|
@ -13703,10 +13707,13 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
if (SR[j] == -INFINITY) {
|
||||
SW[j] = 0.0f;
|
||||
} else {
|
||||
// const float val = expf(SR[j] - max);
|
||||
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
||||
const float val = expf(SR[j] - max);
|
||||
#else
|
||||
ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
|
||||
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
||||
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
|
||||
#endif
|
||||
sump[j] += (ggml_float)val;
|
||||
SW[j] = val;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue