match to metal impl
This commit is contained in:
parent
972c2adc15
commit
0fc36d872c
1 changed files with 5 additions and 5 deletions
10
ggml-cuda.cu
10
ggml-cuda.cu
|
@ -6225,7 +6225,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
s += ss[hiiw*tph + i];
|
s += ss[hiiw*tph + i];
|
||||||
}
|
}
|
||||||
|
|
||||||
s = s * scale_h + mv; // s*scale + mv
|
s = s*scale_h + mv; // s*scale + mv
|
||||||
|
|
||||||
half2 m = M;
|
half2 m = M;
|
||||||
|
|
||||||
|
@ -6234,7 +6234,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
half2 ms = h2exp(m - M);
|
half2 ms = h2exp(m - M);
|
||||||
half2 vs = h2exp(s - M);
|
half2 vs = h2exp(s - M);
|
||||||
|
|
||||||
S = S * ms + vs;
|
S = S*ms + vs;
|
||||||
|
|
||||||
ss[2*hiiw + 0] = ms;
|
ss[2*hiiw + 0] = ms;
|
||||||
ss[2*hiiw + 1] = vs;
|
ss[2*hiiw + 1] = vs;
|
||||||
|
@ -6247,7 +6247,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < D2/tph; ++i) {
|
for (int i = 0; i < D2/tph; ++i) {
|
||||||
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms + pv2[tph*i + tiih] * vs;
|
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms + pv2[tph*i + tiih]*vs;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -6272,7 +6272,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
half2 ms0 = h2exp(M0 - M);
|
half2 ms0 = h2exp(M0 - M);
|
||||||
half2 ms1 = h2exp(M1 - M);
|
half2 ms1 = h2exp(M1 - M);
|
||||||
|
|
||||||
S = S0 * ms0 + S1 * ms1;
|
S = S0*ms0 + S1*ms1;
|
||||||
|
|
||||||
if (tiih == 0) {
|
if (tiih == 0) {
|
||||||
ss[2*hiiw + 0] = S;
|
ss[2*hiiw + 0] = S;
|
||||||
|
@ -6280,7 +6280,7 @@ static __global__ void flash_attn_ext_f16(
|
||||||
}
|
}
|
||||||
|
|
||||||
for (int i = 0; i < D2/tph; ++i) {
|
for (int i = 0; i < D2/tph; ++i) {
|
||||||
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih] * ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih] * ms1;
|
ps2[hiiw*D2 + tph*i + tiih] = ps2[hiiw*D2 + tph*i + tiih]*ms0 + ps2[sg*(R*D + 32)/4 + hiiw*D2 + tph*i + tiih]*ms1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue