From ec57689f6466bdd68480b274e118c95d6164366d Mon Sep 17 00:00:00 2001 From: xaedes Date: Wed, 13 Sep 2023 18:34:06 +0200 Subject: [PATCH] exclude known zero values from computations in flash_attn_f32 & flash_attn_back_f32 --- ggml.c | 85 ++++++++++++++++++++++++++++------------------------------ 1 file changed, 41 insertions(+), 44 deletions(-) diff --git a/ggml.c b/ggml.c index b7dcb2c51..6296809a4 100644 --- a/ggml.c +++ b/ggml.c @@ -14475,20 +14475,17 @@ static void ggml_compute_forward_flash_attn_f32( // scale ggml_vec_scale_f32(nek1, S, scale); - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; } // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // exclude known -INF S[..] values from max and loop // dont forget to set their SW values to zero { float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); + ggml_vec_max_f32(masked_begin, &max, S); ggml_float sum = 0.0; { @@ -14502,10 +14499,15 @@ static void ggml_compute_forward_flash_attn_f32( ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + if (i >= masked_begin) { + break; + } float * SS = S + i; for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SS[j] == -INFINITY) { + if (i + j >= masked_begin) { + break; + } else if (SS[j] == -INFINITY) { SS[j] = 0.0f; } else { #ifndef GGML_FLASH_ATTN_EXP_FP16 @@ -14530,10 +14532,10 @@ static void ggml_compute_forward_flash_attn_f32( assert(sum > 0.0); sum = 1.0/sum; - ggml_vec_scale_f32(M, S, sum); + ggml_vec_scale_f32(masked_begin, S, sum); #ifndef NDEBUG - for (int i = 0; i < M; ++i) { + for (int i = 0; i < masked_begin; ++i) { assert(!isnan(S[i])); assert(!isinf(S[i])); } @@ -14550,7 +14552,7 @@ static void ggml_compute_forward_flash_attn_f32( const int iv2 = iq2 % nev2; const int iv3 = iq3; - ggml_vec_dot_f32(nev0, + ggml_vec_dot_f32(masked_begin, (float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), S); @@ -15126,20 +15128,17 @@ static void ggml_compute_forward_flash_attn_back_f32( // scale ggml_vec_scale_f32(nek1, S, scale); - if (masked) { - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = -INFINITY; - } - } + const int64_t masked_begin = masked ? (P + iq1 + 1) : M; + for (int64_t i = masked_begin; i < M; i++) { + S[i] = -INFINITY; } // softmax - // todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero. + // exclude known -INF S[..] values from max and loop // dont forget to set their SM values to zero { float max = -INFINITY; - ggml_vec_max_f32(M, &max, S); + ggml_vec_max_f32(masked_begin, &max, S); ggml_float sum = 0.0; { @@ -15153,11 +15152,16 @@ static void ggml_compute_forward_flash_attn_back_f32( ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 }; for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) { + if (i >= masked_begin) { + break; + } float * SR = S + i; float * SW = SM + i; for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) { - if (SR[j] == -INFINITY) { + if (i + j >= masked_begin) { + break; + } else if (SR[j] == -INFINITY) { SW[j] = 0.0f; } else { #ifndef GGML_FLASH_ATTN_EXP_FP16 @@ -15182,7 +15186,7 @@ static void ggml_compute_forward_flash_attn_back_f32( assert(sum > 0.0); sum = 1.0/sum; - ggml_vec_scale_f32(M, SM, sum); + ggml_vec_scale_f32(masked_begin, SM, sum); } @@ -15253,9 +15257,10 @@ static void ggml_compute_forward_flash_attn_back_f32( // S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3] // for ic: // S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3] - ggml_vec_set_f32(M, S, 0); + // exclude known future zero S[..] values from operation + ggml_vec_set_f32(masked_begin, S, 0); for (int64_t ic = 0; ic < D; ++ic) { - ggml_vec_mad_f32(M, + ggml_vec_mad_f32(masked_begin, S, (float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3))); @@ -15263,23 +15268,15 @@ static void ggml_compute_forward_flash_attn_back_f32( // S = SM * (S - dot(SM, S)) float dot_SM_gradSM = 0; - ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S); + ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S); ggml_vec_acc1_f32(M, S, -dot_SM_gradSM); - ggml_vec_mul_f32 (M, S, S, SM); + ggml_vec_mul_f32 (masked_begin, S, S, SM); // S = diag_mask_zero(S, P) * scale - if (masked) { - // for (int64_t i = P + iq1 + 1; i < M; i++) { - // S[i] = 0; - // } - for (int64_t i = P; i < M; i++) { - if (i > P + iq1) { - S[i] = 0; - } - } - } - // todo: exclude known zero S[..] values from operation - ggml_vec_scale_f32(M, S, scale); + // already done by above ggml_vec_set_f32 + + // exclude known zero S[..] values from operation + ggml_vec_scale_f32(masked_begin, S, scale); // S shape [M,1] // SM shape [M,1] @@ -15291,8 +15288,8 @@ static void ggml_compute_forward_flash_attn_back_f32( // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M] // for ic: // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3] - // todo: exclude known zero S[..] values from loop - for (int64_t ic = 0; ic < M; ++ic) { + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { ggml_vec_mad_f32(D, (float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)), (float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)), @@ -15303,8 +15300,8 @@ static void ggml_compute_forward_flash_attn_back_f32( // for ic: // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0] // grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0] - // todo: exclude known zero S[..] values from loop - for (int64_t ic = 0; ic < M; ++ic) { + // exclude known zero S[..] values from loop + for (int64_t ic = 0; ic < masked_begin; ++ic) { ggml_vec_mad_f32(D, (float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)), (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), @@ -15315,9 +15312,9 @@ static void ggml_compute_forward_flash_attn_back_f32( // for ic: // grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M] // grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M] - // todo: exclude known zero SM[..] values from mad + // exclude known zero SM[..] values from mad for (int64_t ic = 0; ic < D; ++ic) { - ggml_vec_mad_f32(M, + ggml_vec_mad_f32(masked_begin, (float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)), SM, *(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));