metal : use the unrolled loop variable
This commit is contained in:
parent
40e717263e
commit
fd7d5e870d
1 changed files with 7 additions and 7 deletions
|
@ -2780,7 +2780,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
short i = ii + tiisg;
|
||||
mq[i/NW] = (float4) sq4[i];
|
||||
mq[ii/NW] = (float4) sq4[i];
|
||||
}
|
||||
|
||||
// pointer to the mask
|
||||
|
@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
mk[2] = (float4) pk4[i + 2*(nb11/8)];
|
||||
mk[3] = (float4) pk4[i + 3*(nb11/8)];
|
||||
|
||||
mqk += (float4) (mq[i/NW] * mk);
|
||||
mqk += (float4) (mq[ii/NW] * mk);
|
||||
}
|
||||
|
||||
// reduce the results from the threads in the simdgroup
|
||||
|
@ -2858,7 +2858,7 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
#pragma unroll
|
||||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
const short i = ii + tiisg;
|
||||
lo[i/NW] *= ms;
|
||||
lo[ii/NW] *= ms;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2872,10 +2872,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
|
|||
for (short ii = 0; ii < D4; ii += NW) {
|
||||
const short i = ii + tiisg;
|
||||
|
||||
lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
|
||||
lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
|
||||
lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
|
||||
lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
|
||||
lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0];
|
||||
lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1];
|
||||
lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2];
|
||||
lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue