diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 224005d0a..b9ea9f08e 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2816,7 +2816,7 @@ kernel void kernel_flash_attn_ext( threadgroup_barrier(mem_flags::mem_threadgroup); { - float S[Q] = { [0 ... Q-1] = 0.0h }; + float S[Q] = { [0 ... Q-1] = 0.0f }; float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; // thread indices inside the simdgroup @@ -3064,7 +3064,7 @@ kernel void kernel_flash_attn_ext( // reduce the warps sequentially for (short sg = 1; sg < nsg; ++sg) { - float S = { 0.0h }; + float S = { 0.0f }; float M = { -FLT_MAX/2 }; threadgroup_barrier(mem_flags::mem_threadgroup); @@ -3263,7 +3263,7 @@ kernel void kernel_flash_attn_ext_vec( // zero out lo for (short i = 0; i < D16/NW4; i += NW4) { - lo[i] = float4x4(0.0h); + lo[i] = float4x4(0.0f); } // zero out shared memory SH @@ -3274,7 +3274,7 @@ kernel void kernel_flash_attn_ext_vec( threadgroup_barrier(mem_flags::mem_threadgroup); { - float S = 0.0h; + float S = 0.0f; float M = -FLT_MAX/2; // thread indices inside the simdgroup