match to metal impl

This commit is contained in:
FSSRepo 2024-01-24 16:45:30 -05:00
parent 972c2adc15
commit 0fc36d872c

View file

@ -6225,7 +6225,7 @@ static __global__ void flash_attn_ext_f16(
s += ss[hiiw*tph + i];
}
s = s * scale_h + mv; // s*scale + mv
s = s*scale_h + mv; // s*scale + mv
half2 m = M;
@ -6234,7 +6234,7 @@ static __global__ void flash_attn_ext_f16(
half2 ms = h2exp(m - M);
half2 vs = h2exp(s - M);
S = S * ms + vs;
S = S*ms + vs;
ss[2*hiiw + 0] = ms;
ss[2*hiiw + 1] = vs;
@ -6247,7 +6247,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int i = 0; i < D2/tph; ++i) {
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms + pv2[tph*i + tiih] * vs;
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms + pv2[tph*i + tiih]*vs;
}
}
@ -6272,7 +6272,7 @@ static __global__ void flash_attn_ext_f16(
half2 ms0 = h2exp(M0 - M);
half2 ms1 = h2exp(M1 - M);
S = S0 * ms0 + S1 * ms1;
S = S0*ms0 + S1*ms1;
if (tiih == 0) {
ss[2*hiiw + 0] = S;
@ -6280,7 +6280,7 @@ static __global__ void flash_attn_ext_f16(
}
for (int i = 0; i < D2/tph; ++i) {
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih] * ms1;
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih]*ms1;
}
}