int64 dst
This commit is contained in:
parent
3ab47eb746
commit
4abeb60a1a
1 changed file with 2 additions and 2 deletions
|
@@ -3234,10 +3234,10 @@ kernel void kernel_flash_attn_ext(
|
||||||
// final rescale with 1/S and store to global memory
|
// final rescale with 1/S and store to global memory
|
||||||
if (sgitg == 0) {
|
if (sgitg == 0) {
|
||||||
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
|
for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
|
||||||
const float S = ss[j*TS + 0];
|
const half S = ss[j*TS + 0];
|
||||||
|
|
||||||
for (short i = tiisg; i < D4; i += NW) {
|
for (short i = tiisg; i < D4; i += NW) {
|
||||||
dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*T4 + i]/S;
|
dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*T4 + i]/S;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue