metal : use the unrolled loop variable

This commit is contained in:
Georgi Gerganov 2024-11-03 10:02:53 +02:00
parent 40e717263e
commit fd7d5e870d
No known key found for this signature in database
GPG key ID: 449E073F9DC10735

View file

@ -2780,7 +2780,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short ii = 0; ii < D4; ii += NW) {
short i = ii + tiisg;
mq[i/NW] = (float4) sq4[i];
mq[ii/NW] = (float4) sq4[i];
}
// pointer to the mask
@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
mk[2] = (float4) pk4[i + 2*(nb11/8)];
mk[3] = (float4) pk4[i + 3*(nb11/8)];
mqk += (float4) (mq[i/NW] * mk);
mqk += (float4) (mq[ii/NW] * mk);
}
// reduce the results from the threads in the simdgroup
@ -2858,7 +2858,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
#pragma unroll
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
lo[i/NW] *= ms;
lo[ii/NW] *= ms;
}
}
@ -2872,10 +2872,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
for (short ii = 0; ii < D4; ii += NW) {
const short i = ii + tiisg;
lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
}
}
}