exclude known zero values from computations in flash_attn_f32 & flash_attn_back_f32

This commit is contained in:
xaedes 2023-09-13 18:34:06 +02:00
parent 7898652dfb
commit ec57689f64
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

85
ggml.c
View file

@ -14475,20 +14475,17 @@ static void ggml_compute_forward_flash_attn_f32(
// scale
ggml_vec_scale_f32(nek1, S, scale);
if (masked) {
for (int64_t i = P; i < M; i++) {
if (i > P + iq1) {
S[i] = -INFINITY;
}
}
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
for (int64_t i = masked_begin; i < M; i++) {
S[i] = -INFINITY;
}
// softmax
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
// exclude known -INF S[..] values from max and loop
// don't forget to set their SW values to zero
{
float max = -INFINITY;
ggml_vec_max_f32(M, &max, S);
ggml_vec_max_f32(masked_begin, &max, S);
ggml_float sum = 0.0;
{
@ -14502,10 +14499,15 @@ static void ggml_compute_forward_flash_attn_f32(
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
if (i >= masked_begin) {
break;
}
float * SS = S + i;
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
if (SS[j] == -INFINITY) {
if (i + j >= masked_begin) {
break;
} else if (SS[j] == -INFINITY) {
SS[j] = 0.0f;
} else {
#ifndef GGML_FLASH_ATTN_EXP_FP16
@ -14530,10 +14532,10 @@ static void ggml_compute_forward_flash_attn_f32(
assert(sum > 0.0);
sum = 1.0/sum;
ggml_vec_scale_f32(M, S, sum);
ggml_vec_scale_f32(masked_begin, S, sum);
#ifndef NDEBUG
for (int i = 0; i < M; ++i) {
for (int i = 0; i < masked_begin; ++i) {
assert(!isnan(S[i]));
assert(!isinf(S[i]));
}
@ -14550,7 +14552,7 @@ static void ggml_compute_forward_flash_attn_f32(
const int iv2 = iq2 % nev2;
const int iv3 = iq3;
ggml_vec_dot_f32(nev0,
ggml_vec_dot_f32(masked_begin,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S);
@ -15126,20 +15128,17 @@ static void ggml_compute_forward_flash_attn_back_f32(
// scale
ggml_vec_scale_f32(nek1, S, scale);
if (masked) {
for (int64_t i = P; i < M; i++) {
if (i > P + iq1) {
S[i] = -INFINITY;
}
}
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
for (int64_t i = masked_begin; i < M; i++) {
S[i] = -INFINITY;
}
// softmax
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
// exclude known -INF S[..] values from max and loop
// don't forget to set their SM values to zero
{
float max = -INFINITY;
ggml_vec_max_f32(M, &max, S);
ggml_vec_max_f32(masked_begin, &max, S);
ggml_float sum = 0.0;
{
@ -15153,11 +15152,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
if (i >= masked_begin) {
break;
}
float * SR = S + i;
float * SW = SM + i;
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
if (SR[j] == -INFINITY) {
if (i + j >= masked_begin) {
break;
} else if (SR[j] == -INFINITY) {
SW[j] = 0.0f;
} else {
#ifndef GGML_FLASH_ATTN_EXP_FP16
@ -15182,7 +15186,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
assert(sum > 0.0);
sum = 1.0/sum;
ggml_vec_scale_f32(M, SM, sum);
ggml_vec_scale_f32(masked_begin, SM, sum);
}
@ -15253,9 +15257,10 @@ static void ggml_compute_forward_flash_attn_back_f32(
// S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
// for ic:
// S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
ggml_vec_set_f32(M, S, 0);
// exclude known future zero S[..] values from operation
ggml_vec_set_f32(masked_begin, S, 0);
for (int64_t ic = 0; ic < D; ++ic) {
ggml_vec_mad_f32(M,
ggml_vec_mad_f32(masked_begin,
S,
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
@ -15263,23 +15268,15 @@ static void ggml_compute_forward_flash_attn_back_f32(
// S = SM * (S - dot(SM, S))
float dot_SM_gradSM = 0;
ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
ggml_vec_mul_f32 (M, S, S, SM);
ggml_vec_mul_f32 (masked_begin, S, S, SM);
// S = diag_mask_zero(S, P) * scale
if (masked) {
// for (int64_t i = P + iq1 + 1; i < M; i++) {
// S[i] = 0;
// }
for (int64_t i = P; i < M; i++) {
if (i > P + iq1) {
S[i] = 0;
}
}
}
// todo: exclude known zero S[..] values from operation
ggml_vec_scale_f32(M, S, scale);
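// NOTE(review): the ggml_vec_set_f32 above only clears S[0..masked_begin); the masked tail S[masked_begin..M) is not zeroed, so it must never be read by later operations — confirm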
// exclude known zero S[..] values from operation
ggml_vec_scale_f32(masked_begin, S, scale);
// S shape [M,1]
// SM shape [M,1]
@ -15291,8 +15288,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
// for ic:
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
// todo: exclude known zero S[..] values from loop
for (int64_t ic = 0; ic < M; ++ic) {
// exclude known zero S[..] values from loop
for (int64_t ic = 0; ic < masked_begin; ++ic) {
ggml_vec_mad_f32(D,
(float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
(float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
@ -15303,8 +15300,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
// for ic:
// grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
// grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
// todo: exclude known zero S[..] values from loop
for (int64_t ic = 0; ic < M; ++ic) {
// exclude known zero S[..] values from loop
for (int64_t ic = 0; ic < masked_begin; ++ic) {
ggml_vec_mad_f32(D,
(float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
@ -15315,9 +15312,9 @@ static void ggml_compute_forward_flash_attn_back_f32(
// for ic:
// grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
// grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
// todo: exclude known zero SM[..] values from mad
// exclude known zero SM[..] values from mad
for (int64_t ic = 0; ic < D; ++ic) {
ggml_vec_mad_f32(M,
ggml_vec_mad_f32(masked_begin,
(float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
SM,
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));