metal : float-correctness
This commit is contained in:
parent
d805404e2d
commit
73f378df82
1 changed file with 4 additions and 4 deletions
|
@@ -2816,7 +2816,7 @@ kernel void kernel_flash_attn_ext(
|
|||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
-float S[Q] = { [0 ... Q-1] = 0.0h };
|
||||
+float S[Q] = { [0 ... Q-1] = 0.0f };
|
||||
float M[Q] = { [0 ... Q-1] = -FLT_MAX/2 };
|
||||
|
||||
// thread indices inside the simdgroup
|
||||
|
@@ -3064,7 +3064,7 @@ kernel void kernel_flash_attn_ext(
|
|||
|
||||
// reduce the warps sequentially
|
||||
for (short sg = 1; sg < nsg; ++sg) {
|
||||
-float S = { 0.0h };
|
||||
+float S = { 0.0f };
|
||||
float M = { -FLT_MAX/2 };
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
@@ -3263,7 +3263,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||
|
||||
// zero out lo
|
||||
for (short i = 0; i < D16/NW4; i += NW4) {
|
||||
-lo[i] = float4x4(0.0h);
|
||||
+lo[i] = float4x4(0.0f);
|
||||
}
|
||||
|
||||
// zero out shared memory SH
|
||||
|
@@ -3274,7 +3274,7 @@ kernel void kernel_flash_attn_ext_vec(
|
|||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
{
|
||||
-float S = 0.0h;
|
||||
+float S = 0.0f;
|
||||
float M = -FLT_MAX/2;
|
||||
|
||||
// thread indices inside the simdgroup
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue