use precise::tanh
This commit is contained in:
parent
054203ae69
commit
edc2e27383
1 changed files with 3 additions and 3 deletions
|
@ -2149,8 +2149,8 @@ kernel void kernel_flash_attn_ext_f16(
|
|||
ss[8*cc + ty*TF + 2*tx + 1] *= scale;
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*tanh(ss[8*cc + ty*TF + 2*tx + 0]);
|
||||
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*tanh(ss[8*cc + ty*TF + 2*tx + 1]);
|
||||
ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
|
||||
ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
|
||||
}
|
||||
|
||||
if (mask != q) {
|
||||
|
@ -2490,7 +2490,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
mqk *= scale;
|
||||
|
||||
if (logit_softcap != 0.0f) {
|
||||
mqk = logit_softcap*tanh(mqk);
|
||||
mqk = logit_softcap*precise::tanh(mqk);
|
||||
}
|
||||
|
||||
mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue