diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 512a06d74..2e447c928 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2879,7 +2879,7 @@ kernel void kernel_flash_attn_ext( if (iq1 + j < ne01) { sq4[j*T4 + i] = (q4_t) q4[i]; } else { - sq4[j*T4 + i] = (q4_t) (float4) 0.0f; + sq4[j*T4 + i] = (q4_t) 0.0f; } } } @@ -2892,7 +2892,7 @@ kernel void kernel_flash_attn_ext( // zero out shared memory SH for (short j = 0; j < Q; ++j) { for (short i = tiisg; i < SH; i += NW) { - ss[j*TS + i] = (s_t) 0.0f; + ss[j*TS + i] = 0.0f; } } @@ -3044,7 +3044,7 @@ kernel void kernel_flash_attn_ext( const float m = M[j]; // scale and apply the logitcap / mask - float s = ((float)(ss[j*TS + tiisg]))*scale; + float s = ss[j*TS + tiisg]*scale; if (logit_softcap != 0.0f) { s = logit_softcap*precise::tanh(s); @@ -3064,12 +3064,12 @@ kernel void kernel_flash_attn_ext( S[j] = S[j]*ms[j] + simd_sum(vs); // the P matrix from the paper (Q rows, C columns) - ss[j*TS + tiisg] = (s_t) vs; + ss[j*TS + tiisg] = vs; } // create a QxQ diagonal matrix for rescaling the output if (tiisg < Q) { - ss[tiisg*TS + C + tiisg] = (s_t) ms[tiisg]; + ss[tiisg*TS + C + tiisg] = ms[tiisg]; } } @@ -3157,8 +3157,8 @@ kernel void kernel_flash_attn_ext( // these are needed for reducing the results from the simdgroups (reuse the ss buffer) for (short j = 0; j < Q; ++j) { if (tiisg == 0) { - ss[j*TS + 0] = (s_t) S[j]; - ss[j*TS + 1] = (s_t) M[j]; + ss[j*TS + 0] = S[j]; + ss[j*TS + 1] = M[j]; } } } @@ -3196,18 +3196,16 @@ kernel void kernel_flash_attn_ext( S = S0*ms0 + S1*ms1; if (tiisg == 0) { - ss[j*TS + 0] = (s_t) S; - ss[j*TS + 1] = (s_t) M; + ss[j*TS + 0] = S; + ss[j*TS + 1] = M; - ss[j*TS + C + j ] = (s_t) ms0; - ss[j*TS + C + j + sg*SH] = (s_t) ms1; + ss[j*TS + C + j ] = ms0; + ss[j*TS + C + j + sg*SH] = ms1; } } // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 { - o8x8_t t; - s8x8_t ms0; s8x8_t ms1; @@ -3215,6 +3213,8 @@ kernel void kernel_flash_attn_ext( simdgroup_load(ms1, ss + C + sg*SH, TS, 0, false); for (short i = 0; i < D8; ++i) { + o8x8_t t; + simdgroup_load (t, so + i*8, T, 0, false); simdgroup_multiply(t, ms1, t);