diff --git a/ggml.c b/ggml.c index e88c046aa..35e067c68 100644 --- a/ggml.c +++ b/ggml.c @@ -14540,7 +14540,8 @@ static void ggml_compute_forward_flash_attn_f32( S[i] = -INFINITY; } - for (int64_t ic = 0; ic < nek1; ++ic) { + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { // k indices const int ik3 = iq3; const int ik2 = iq2 % nek2; @@ -14556,9 +14557,8 @@ static void ggml_compute_forward_flash_attn_f32( } // scale - ggml_vec_scale_f32(nek1, S, scale); + ggml_vec_scale_f32(masked_begin, S, scale); - const int64_t masked_begin = masked ? (P + iq1 + 1) : M; for (int64_t i = masked_begin; i < M; i++) { S[i] = -INFINITY; } @@ -15195,7 +15195,8 @@ static void ggml_compute_forward_flash_attn_back_f32( S[i] = -INFINITY; } - for (int64_t ic = 0; ic < nek1; ++ic) { + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t ic = 0; ic < masked_begin; ++ic) { // k indices const int ik1 = ic; @@ -15209,9 +15210,8 @@ static void ggml_compute_forward_flash_attn_back_f32( } // scale - ggml_vec_scale_f32(nek1, S, scale); + ggml_vec_scale_f32(masked_begin, S, scale); - const int64_t masked_begin = masked ? (P + iq1 + 1) : M; for (int64_t i = masked_begin; i < M; i++) { S[i] = -INFINITY; }