Exclude more known-zero (masked) values from computations in flash_attn_f32 and flash_attn_back_f32

This commit is contained in:
xaedes 2023-09-14 22:18:20 +02:00
parent d88dae2980
commit 76804fab1d
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

12
ggml.c
View file

@ -14540,7 +14540,8 @@ static void ggml_compute_forward_flash_attn_f32(
S[i] = -INFINITY;
}
for (int64_t ic = 0; ic < nek1; ++ic) {
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
for (int64_t ic = 0; ic < masked_begin; ++ic) {
// k indices
const int ik3 = iq3;
const int ik2 = iq2 % nek2;
@ -14556,9 +14557,8 @@ static void ggml_compute_forward_flash_attn_f32(
}
// scale
ggml_vec_scale_f32(nek1, S, scale);
ggml_vec_scale_f32(masked_begin, S, scale);
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
for (int64_t i = masked_begin; i < M; i++) {
S[i] = -INFINITY;
}
@ -15195,7 +15195,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
S[i] = -INFINITY;
}
for (int64_t ic = 0; ic < nek1; ++ic) {
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
for (int64_t ic = 0; ic < masked_begin; ++ic) {
// k indices
const int ik1 = ic;
@ -15209,9 +15210,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
}
// scale
ggml_vec_scale_f32(nek1, S, scale);
ggml_vec_scale_f32(masked_begin, S, scale);
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
for (int64_t i = masked_begin; i < M; i++) {
S[i] = -INFINITY;
}