diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e9657dd88..b7ebfcc57 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -6225,7 +6225,7 @@ static __global__ void flash_attn_ext_f16( s += ss[hiiw*tph + i]; } - s = s * scale_h + mv; // s*scale + mv + s = s*scale_h + mv; // s*scale + mv half2 m = M; @@ -6234,7 +6234,7 @@ static __global__ void flash_attn_ext_f16( half2 ms = h2exp(m - M); half2 vs = h2exp(s - M); - S = S * ms + vs; + S = S*ms + vs; ss[2*hiiw + 0] = ms; ss[2*hiiw + 1] = vs; @@ -6247,7 +6247,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int i = 0; i < D2/tph; ++i) { - ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms + pv2[tph*i + tiih] * vs; + ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms + pv2[tph*i + tiih]*vs; } } @@ -6272,7 +6272,7 @@ static __global__ void flash_attn_ext_f16( half2 ms0 = h2exp(M0 - M); half2 ms1 = h2exp(M1 - M); - S = S0 * ms0 + S1 * ms1; + S = S0*ms0 + S1*ms1; if (tiih == 0) { ss[2*hiiw + 0] = S; @@ -6280,7 +6280,7 @@ static __global__ void flash_attn_ext_f16( } for (int i = 0; i < D2/tph; ++i) { - ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih] * ms1; + ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih]*ms1; } }