exclude some more known zero values from computations in flash_attn_f32 & flash_attn_back_f32
This commit is contained in:
parent
d88dae2980
commit
76804fab1d
1 changed files with 6 additions and 6 deletions
12
ggml.c
12
ggml.c
|
@ -14540,7 +14540,8 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||
S[i] = -INFINITY;
|
||||
}
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
||||
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
||||
// k indices
|
||||
const int ik3 = iq3;
|
||||
const int ik2 = iq2 % nek2;
|
||||
|
@ -14556,9 +14557,8 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||
}
|
||||
|
||||
// scale
|
||||
ggml_vec_scale_f32(nek1, S, scale);
|
||||
ggml_vec_scale_f32(masked_begin, S, scale);
|
||||
|
||||
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
||||
for (int64_t i = masked_begin; i < M; i++) {
|
||||
S[i] = -INFINITY;
|
||||
}
|
||||
|
@ -15195,7 +15195,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
S[i] = -INFINITY;
|
||||
}
|
||||
|
||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
||||
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
||||
// k indices
|
||||
const int ik1 = ic;
|
||||
|
||||
|
@ -15209,9 +15210,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
}
|
||||
|
||||
// scale
|
||||
ggml_vec_scale_f32(nek1, S, scale);
|
||||
ggml_vec_scale_f32(masked_begin, S, scale);
|
||||
|
||||
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
||||
for (int64_t i = masked_begin; i < M; i++) {
|
||||
S[i] = -INFINITY;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue