From fd7d5e870d7f09308a07f835f3b73c6a166a3ebb Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 3 Nov 2024 10:02:53 +0200 Subject: [PATCH] metal : use the unrolled loop variable --- ggml/src/ggml-metal.metal | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 6fc114e29..fed076f7f 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2780,7 +2780,7 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short ii = 0; ii < D4; ii += NW) { short i = ii + tiisg; - mq[i/NW] = (float4) sq4[i]; + mq[ii/NW] = (float4) sq4[i]; } // pointer to the mask @@ -2812,7 +2812,7 @@ kernel void kernel_flash_attn_ext_vec_f16( mk[2] = (float4) pk4[i + 2*(nb11/8)]; mk[3] = (float4) pk4[i + 3*(nb11/8)]; - mqk += (float4) (mq[i/NW] * mk); + mqk += (float4) (mq[ii/NW] * mk); } // reduce the results from the threads in the simdgroup @@ -2858,7 +2858,7 @@ kernel void kernel_flash_attn_ext_vec_f16( #pragma unroll for (short ii = 0; ii < D4; ii += NW) { const short i = ii + tiisg; - lo[i/NW] *= ms; + lo[ii/NW] *= ms; } } @@ -2872,10 +2872,10 @@ kernel void kernel_flash_attn_ext_vec_f16( for (short ii = 0; ii < D4; ii += NW) { const short i = ii + tiisg; - lo[i/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; - lo[i/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; - lo[i/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; - lo[i/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; + lo[ii/NW] += pv4[i + 0*(nb21/8)] * ss[4*cc + 0]; + lo[ii/NW] += pv4[i + 1*(nb21/8)] * ss[4*cc + 1]; + lo[ii/NW] += pv4[i + 2*(nb21/8)] * ss[4*cc + 2]; + lo[ii/NW] += pv4[i + 3*(nb21/8)] * ss[4*cc + 3]; } } }