exclude known zero values from computations in flash_attn_f32 & flash_attn_back_f32
This commit is contained in:
parent
7898652dfb
commit
ec57689f64
1 changed file with 41 additions and 44 deletions
85
ggml.c
85
ggml.c
|
@ -14475,20 +14475,17 @@ static void ggml_compute_forward_flash_attn_f32(
|
||||||
// scale
|
// scale
|
||||||
ggml_vec_scale_f32(nek1, S, scale);
|
ggml_vec_scale_f32(nek1, S, scale);
|
||||||
|
|
||||||
if (masked) {
|
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
||||||
for (int64_t i = P; i < M; i++) {
|
for (int64_t i = masked_begin; i < M; i++) {
|
||||||
if (i > P + iq1) {
|
S[i] = -INFINITY;
|
||||||
S[i] = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// softmax
|
// softmax
|
||||||
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
|
// exclude known -INF S[..] values from max and loop
|
||||||
// don't forget to set their SW values to zero
|
// don't forget to set their SW values to zero
|
||||||
{
|
{
|
||||||
float max = -INFINITY;
|
float max = -INFINITY;
|
||||||
ggml_vec_max_f32(M, &max, S);
|
ggml_vec_max_f32(masked_begin, &max, S);
|
||||||
|
|
||||||
ggml_float sum = 0.0;
|
ggml_float sum = 0.0;
|
||||||
{
|
{
|
||||||
|
@ -14502,10 +14499,15 @@ static void ggml_compute_forward_flash_attn_f32(
|
||||||
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
||||||
|
|
||||||
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
||||||
|
if (i >= masked_begin) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
float * SS = S + i;
|
float * SS = S + i;
|
||||||
|
|
||||||
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
||||||
if (SS[j] == -INFINITY) {
|
if (i + j >= masked_begin) {
|
||||||
|
break;
|
||||||
|
} else if (SS[j] == -INFINITY) {
|
||||||
SS[j] = 0.0f;
|
SS[j] = 0.0f;
|
||||||
} else {
|
} else {
|
||||||
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
||||||
|
@ -14530,10 +14532,10 @@ static void ggml_compute_forward_flash_attn_f32(
|
||||||
assert(sum > 0.0);
|
assert(sum > 0.0);
|
||||||
|
|
||||||
sum = 1.0/sum;
|
sum = 1.0/sum;
|
||||||
ggml_vec_scale_f32(M, S, sum);
|
ggml_vec_scale_f32(masked_begin, S, sum);
|
||||||
|
|
||||||
#ifndef NDEBUG
|
#ifndef NDEBUG
|
||||||
for (int i = 0; i < M; ++i) {
|
for (int i = 0; i < masked_begin; ++i) {
|
||||||
assert(!isnan(S[i]));
|
assert(!isnan(S[i]));
|
||||||
assert(!isinf(S[i]));
|
assert(!isinf(S[i]));
|
||||||
}
|
}
|
||||||
|
@ -14550,7 +14552,7 @@ static void ggml_compute_forward_flash_attn_f32(
|
||||||
const int iv2 = iq2 % nev2;
|
const int iv2 = iq2 % nev2;
|
||||||
const int iv3 = iq3;
|
const int iv3 = iq3;
|
||||||
|
|
||||||
ggml_vec_dot_f32(nev0,
|
ggml_vec_dot_f32(masked_begin,
|
||||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||||
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||||
S);
|
S);
|
||||||
|
@ -15126,20 +15128,17 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
// scale
|
// scale
|
||||||
ggml_vec_scale_f32(nek1, S, scale);
|
ggml_vec_scale_f32(nek1, S, scale);
|
||||||
|
|
||||||
if (masked) {
|
const int64_t masked_begin = masked ? (P + iq1 + 1) : M;
|
||||||
for (int64_t i = P; i < M; i++) {
|
for (int64_t i = masked_begin; i < M; i++) {
|
||||||
if (i > P + iq1) {
|
S[i] = -INFINITY;
|
||||||
S[i] = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// softmax
|
// softmax
|
||||||
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
|
// exclude known -INF S[..] values from max and loop
|
||||||
// don't forget to set their SM values to zero
|
// don't forget to set their SM values to zero
|
||||||
{
|
{
|
||||||
float max = -INFINITY;
|
float max = -INFINITY;
|
||||||
ggml_vec_max_f32(M, &max, S);
|
ggml_vec_max_f32(masked_begin, &max, S);
|
||||||
|
|
||||||
ggml_float sum = 0.0;
|
ggml_float sum = 0.0;
|
||||||
{
|
{
|
||||||
|
@ -15153,11 +15152,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
||||||
|
|
||||||
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
||||||
|
if (i >= masked_begin) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
float * SR = S + i;
|
float * SR = S + i;
|
||||||
float * SW = SM + i;
|
float * SW = SM + i;
|
||||||
|
|
||||||
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
||||||
if (SR[j] == -INFINITY) {
|
if (i + j >= masked_begin) {
|
||||||
|
break;
|
||||||
|
} else if (SR[j] == -INFINITY) {
|
||||||
SW[j] = 0.0f;
|
SW[j] = 0.0f;
|
||||||
} else {
|
} else {
|
||||||
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
||||||
|
@ -15182,7 +15186,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
assert(sum > 0.0);
|
assert(sum > 0.0);
|
||||||
|
|
||||||
sum = 1.0/sum;
|
sum = 1.0/sum;
|
||||||
ggml_vec_scale_f32(M, SM, sum);
|
ggml_vec_scale_f32(masked_begin, SM, sum);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -15253,9 +15257,10 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
// S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
|
// S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
|
||||||
// for ic:
|
// for ic:
|
||||||
// S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
|
// S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
|
||||||
ggml_vec_set_f32(M, S, 0);
|
// exclude known future zero S[..] values from operation
|
||||||
|
ggml_vec_set_f32(masked_begin, S, 0);
|
||||||
for (int64_t ic = 0; ic < D; ++ic) {
|
for (int64_t ic = 0; ic < D; ++ic) {
|
||||||
ggml_vec_mad_f32(M,
|
ggml_vec_mad_f32(masked_begin,
|
||||||
S,
|
S,
|
||||||
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||||
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
|
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
|
||||||
|
@ -15263,23 +15268,15 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
|
|
||||||
// S = SM * (S - dot(SM, S))
|
// S = SM * (S - dot(SM, S))
|
||||||
float dot_SM_gradSM = 0;
|
float dot_SM_gradSM = 0;
|
||||||
ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
|
ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
|
||||||
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
|
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
|
||||||
ggml_vec_mul_f32 (M, S, S, SM);
|
ggml_vec_mul_f32 (masked_begin, S, S, SM);
|
||||||
|
|
||||||
// S = diag_mask_zero(S, P) * scale
|
// S = diag_mask_zero(S, P) * scale
|
||||||
if (masked) {
|
// already done by above ggml_vec_set_f32
|
||||||
// for (int64_t i = P + iq1 + 1; i < M; i++) {
|
|
||||||
// S[i] = 0;
|
// exclude known zero S[..] values from operation
|
||||||
// }
|
ggml_vec_scale_f32(masked_begin, S, scale);
|
||||||
for (int64_t i = P; i < M; i++) {
|
|
||||||
if (i > P + iq1) {
|
|
||||||
S[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// todo: exclude known zero S[..] values from operation
|
|
||||||
ggml_vec_scale_f32(M, S, scale);
|
|
||||||
|
|
||||||
// S shape [M,1]
|
// S shape [M,1]
|
||||||
// SM shape [M,1]
|
// SM shape [M,1]
|
||||||
|
@ -15291,8 +15288,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
|
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
|
||||||
// for ic:
|
// for ic:
|
||||||
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
|
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
|
||||||
// todo: exclude known zero S[..] values from loop
|
// exclude known zero S[..] values from loop
|
||||||
for (int64_t ic = 0; ic < M; ++ic) {
|
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
||||||
ggml_vec_mad_f32(D,
|
ggml_vec_mad_f32(D,
|
||||||
(float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
|
(float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
|
||||||
(float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
(float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||||
|
@ -15303,8 +15300,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
// for ic:
|
// for ic:
|
||||||
// grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
|
// grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
|
||||||
// grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
|
// grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
|
||||||
// todo: exclude known zero S[..] values from loop
|
// exclude known zero S[..] values from loop
|
||||||
for (int64_t ic = 0; ic < M; ++ic) {
|
for (int64_t ic = 0; ic < masked_begin; ++ic) {
|
||||||
ggml_vec_mad_f32(D,
|
ggml_vec_mad_f32(D,
|
||||||
(float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
|
(float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
|
||||||
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
|
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
|
||||||
|
@ -15315,9 +15312,9 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
// for ic:
|
// for ic:
|
||||||
// grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
|
// grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
|
||||||
// grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
|
// grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
|
||||||
// todo: exclude known zero SM[..] values from mad
|
// exclude known zero SM[..] values from mad
|
||||||
for (int64_t ic = 0; ic < D; ++ic) {
|
for (int64_t ic = 0; ic < D; ++ic) {
|
||||||
ggml_vec_mad_f32(M,
|
ggml_vec_mad_f32(masked_begin,
|
||||||
(float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
|
(float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
|
||||||
SM,
|
SM,
|
||||||
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
|
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue