metal : float-correctness

This commit is contained in:
Georgi Gerganov 2024-11-05 09:24:06 +02:00
parent d805404e2d
commit 73f378df82
No known key found for this signature in database
GPG key ID: BF970631944C16B7

View file

@ -2816,7 +2816,7 @@ kernel void kernel_flash_attn_ext(
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
{ {
float S[Q] = { [0 ... Q-1] = 0.0h }; float S[Q] = { [0 ... Q-1] = 0.0f };
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 }; float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
// thread indices inside the simdgroup // thread indices inside the simdgroup
@ -3064,7 +3064,7 @@ kernel void kernel_flash_attn_ext(
// reduce the warps sequentially // reduce the warps sequentially
for (short sg = 1; sg < nsg; ++sg) { for (short sg = 1; sg < nsg; ++sg) {
float S = { 0.0h }; float S = { 0.0f };
float M = { -FLT_MAX/2 }; float M = { -FLT_MAX/2 };
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
@ -3263,7 +3263,7 @@ kernel void kernel_flash_attn_ext_vec(
// zero out lo // zero out lo
for (short i = 0; i < D16/NW4; i += NW4) { for (short i = 0; i < D16/NW4; i += NW4) {
lo[i] = float4x4(0.0h); lo[i] = float4x4(0.0f);
} }
// zero out shared memory SH // zero out shared memory SH
@ -3274,7 +3274,7 @@ kernel void kernel_flash_attn_ext_vec(
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
{ {
float S = 0.0h; float S = 0.0f;
float M = -FLT_MAX/2; float M = -FLT_MAX/2;
// thread indices inside the simdgroup // thread indices inside the simdgroup