bugfixes for backward pass of flash attention

xaedes 2023-05-29 23:45:58 +02:00
parent 22a7279ffb
commit 38560b6d51

ggml.c

@@ -6221,7 +6221,6 @@ struct ggml_tensor * ggml_flash_attn(
     bool is_node = false;

     if (q->grad || k->grad || v->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
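
With the GGML_ASSERT(false) gone, ggml_flash_attn nodes can take part in gradient computation instead of aborting. A minimal sketch of what this unlocks, against the ggml API of this period; the sizes, data values, and loss construction are illustrative assumptions, not part of the commit:

    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = { /*mem_size*/ 64*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ false };
        struct ggml_context * ctx = ggml_init(ip);

        // hypothetical sizes: head dim D, N query rows, M key/value rows
        const int D = 32, N = 8, M = 8;
        struct ggml_tensor * q = ggml_set_f32(ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, N), 0.1f);
        struct ggml_tensor * k = ggml_set_f32(ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, M), 0.1f);
        struct ggml_tensor * v = ggml_set_f32(ggml_new_tensor_2d(ctx, GGML_TYPE_F32, M, D), 0.1f);
        ggml_set_param(ctx, q); // allocates q->grad, so is_node turns true above
        ggml_set_param(ctx, k);
        ggml_set_param(ctx, v);

        struct ggml_tensor * out  = ggml_flash_attn(ctx, q, k, v, /*masked=*/true);
        struct ggml_tensor * loss = ggml_sum(ctx, out);

        struct ggml_cgraph gf = ggml_build_forward(loss);
        struct ggml_cgraph gb = ggml_build_backward(ctx, &gf, /*keep=*/false);
        gb.n_threads = 1;
        ggml_graph_compute(ctx, &gb); // previously died in the assert above

        ggml_free(ctx);
        return 0;
    }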
@@ -12882,10 +12881,15 @@ static void ggml_compute_forward_flash_attn_back_f32(
     //const int64_t nev2 = v->ne[2];
     //const int64_t nev3 = v->ne[3];

+    const int64_t ned0 = d->ne[0];
+    const int64_t ned1 = d->ne[1];
+    //const int64_t ned2 = d->ne[2];
+    //const int64_t ned3 = d->ne[3];
+
     const int64_t ne0 = dst->ne[0];
     const int64_t ne1 = dst->ne[1];
-    //const int64_t ne2 = dst->ne[2];
-    //const int64_t ne3 = dst->ne[3];
+    const int64_t ne2 = dst->ne[2];
+    const int64_t ne3 = dst->ne[3];

     const int nbk0 = k->nb[0];
     const int nbk1 = k->nb[1];
@@ -12923,8 +12927,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
     const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
     const int mxDM = MAX(D, Mup);

-    GGML_ASSERT(ne0 == D);
-    GGML_ASSERT(ne1 == N);
+    // GGML_ASSERT(ne0 == D);
+    // GGML_ASSERT(ne1 == N);
     GGML_ASSERT(P >= 0);

     GGML_ASSERT(nbq0 == sizeof(float));
@@ -12934,10 +12938,12 @@ static void ggml_compute_forward_flash_attn_back_f32(
     GGML_ASSERT(neq0 == D);
     GGML_ASSERT(nek0 == D);
     GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned0 == D);

     GGML_ASSERT(neq1 == N);
     GGML_ASSERT(nek1 == N + P);
     GGML_ASSERT(nev1 == D);
+    GGML_ASSERT(ned1 == N);

     // dst cannot be transposed or permuted
     GGML_ASSERT(nb0 == sizeof(float));
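
The ne0 == D and ne1 == N asserts disappear because dst of flash_attn_back no longer has the shape of the forward output: it now packs all three input gradients back to back (see the grad_q/grad_k/grad_v offsets further down), while the new ned* asserts instead pin d, the incoming gradient, to the forward output's shape. The element count this implies for the packed dst, derived from those offsets as a consistency check (not code from the diff):

    numel(dst) = D*N*neq2*neq3   // grad[q], at byte offset 0
               + D*M*neq2*neq3   // grad[k], at byte offset nb0*D*N*neq2*neq3
               + M*D*neq2*neq3   // grad[v], at byte offset nb0*(D*N + D*M)*neq2*neq3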
@@ -12946,6 +12952,9 @@ static void ggml_compute_forward_flash_attn_back_f32(
     GGML_ASSERT(nb2 <= nb3);

     if (params->type == GGML_TASK_INIT) {
+        if (ith == 0) {
+            memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
+        }
         return;
     }
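
The new memset is load-bearing: further down, gradient contributions are accumulated with ggml_vec_mad_f32 (y[i] += x[i]*s) across all query rows iq1, so dst must start from zero, and exactly one thread (ith == 0) clears it during the INIT phase. A self-contained toy showing the pattern; vec_mad here is a stand-in for ggml_vec_mad_f32:

    #include <stdio.h>
    #include <string.h>

    // stand-in for ggml_vec_mad_f32: y[i] += x[i]*s
    static void vec_mad(int n, float * y, const float * x, float s) {
        for (int i = 0; i < n; ++i) y[i] += x[i]*s;
    }

    int main(void) {
        float grad[4] = { 1e30f, 1e30f, 1e30f, 1e30f }; // pretend stale memory
        const float col[4] = { 1, 2, 3, 4 };

        memset(grad, 0, sizeof(grad)); // the GGML_TASK_INIT clear from this commit
        vec_mad(4, grad, col, 0.5f);   // contribution of one iq1
        vec_mad(4, grad, col, 0.5f);   // later iq1 contributions accumulate on top

        printf("%g %g %g %g\n", grad[0], grad[1], grad[2], grad[3]); // 1 2 3 4
        return 0;
    }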
@@ -12956,7 +12965,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
     // parallelize by q rows using ggml_vec_dot_f32

     // total rows in q
-    const int nr = neq1*neq2*neq3;
+    const int nr = neq2*neq3;

     // rows per thread
     const int dr = (nr + nth - 1)/nth;
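
Dropping neq1 from nr is the matching threading fix: grad[k][:,:,iq2,iq3] and grad[v][:,:,iq2,iq3] receive += contributions from every query row iq1, so two threads handling different iq1 of the same (iq2,iq3) slice would race on the same bytes. Parallelizing over (iq2,iq3) only, with iq1 as an inner loop (as the hunk below does), gives each thread exclusive ownership of its gradient slices. A runnable sketch of the new index decomposition, with made-up sizes:

    #include <stdio.h>

    int main(void) {
        const int neq1 = 3, neq2 = 2, neq3 = 2; // illustrative sizes
        const int nr = neq2*neq3;               // work items; was neq1*neq2*neq3
        for (int ir = 0; ir < nr; ++ir) {       // each ir is owned by one thread
            const int iq3 = ir/neq2;
            const int iq2 = ir - iq3*neq2;
            for (int iq1 = 0; iq1 < neq1; ++iq1) {
                // all += updates to grad[k]/grad[v] at (iq2,iq3) stay on this thread
                printf("ir=%d -> iq1=%d iq2=%d iq3=%d\n", ir, iq1, iq2, iq3);
            }
        }
        return 0;
    }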
@@ -12971,253 +12980,263 @@ static void ggml_compute_forward_flash_attn_back_f32(
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
-        const int iq3 = ir/(neq2*neq1);
-        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
-        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
-
-        // not sure about CACHE_LINE_SIZE_F32..
-        // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
-        float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
-        float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
-
-        for (int i = M; i < Mup; ++i) {
-            S[i] = -INFINITY;
-        }
-
-        for (int64_t ic = 0; ic < nek1; ++ic) {
-            // k indices
-            const int ik3 = iq3;
-            const int ik2 = iq2;
-            const int ik1 = ic;
-
-            // S indices
-            const int i1 = ik1;
-
-            ggml_vec_dot_f32(neq0,
-                    S + i1,
-                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
-                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
-        }
-
-        // scale
-        ggml_vec_scale_f32(nek1, S, scale);
-
-        if (masked) {
-            for (int64_t i = P; i < M; i++) {
-                if (i > P + iq1) {
-                    S[i] = -INFINITY;
-                }
-            }
-        }
-
-        // softmax
-        {
-            float max = -INFINITY;
-            ggml_vec_max_f32(M, &max, S);
-
-            ggml_float sum = 0.0;
-            {
-#ifdef GGML_SOFT_MAX_ACCELERATE
-                max = -max;
-                vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
-                vvexpf(SM, SM, &Mup);
-                ggml_vec_sum_f32(Mup, &sum, SM);
-#else
-                uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
-                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
-
-                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
-                    float * SS = SM + i;
-
-                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
-                        if (SS[j] == -INFINITY) {
-                            SS[j] = 0.0f;
-                        } else {
-                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
-                            memcpy(&scvt[j], &s, sizeof(uint16_t));
-                            const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
-                            sump[j] += (ggml_float)val;
-                            SS[j] = val;
-                        }
-                    }
-                }
-
-                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
-                    sum += sump[i];
-                }
-#endif
-            }
-
-            assert(sum > 0.0);
-
-            sum = 1.0/sum;
-            ggml_vec_scale_f32(M, SM, sum);
-
-        }
-
-        // step-by-step explanation
-        {
-            // forward-process                    shape      grads from backward process
-            // parallel_for iq2,iq3:
-            //  k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,iq2,iq3]  += grad[kcur]
-            //  q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
-            //  v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iq2,iq3]  += grad[vcur]
-            //  for iq1:
-            //   kcur  = k[:D,:M,iq2,iq3]        [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
-            //   qcur  = q[:D,iq1,iq2,iq3]       [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
-            //   vcur  = v[:M,:D,iq2,iq3]        [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
-            //   S0    = -Inf                    [D,1,1,1]
-            //  ~S1[i] = dot(kcur[:D,i], qcur)
-            //   S1    = qcur.T @ kcur           [M,1,1,1]  grad[S1]   = grad[S2] * scale
-            //   S2    = S1 * scale              [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
-            //   S3    = diag_mask_inf(S2, P)    [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-            //   S4    = softmax(S3)             [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
-            //  ~S5[i] = dot(vcur[:,i], S4)
-            //   S5    = S4.T @ vcur             [D,1,1,1]  grad[S5]   = d[:D,iq1,iq2,iq3]
-            //  ~dst[i,iq1,iq2,iq3]  = S5[i]                          ^
-            //   dst[:D,iq1,iq2,iq3] = S5        |  grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
-            //   dst                  backward-/    grad[dst]                 = d
-            //
-            // output gradients with their dependencies:
-            //
-            // grad[kcur] = grad[S1].T @ qcur
-            // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
-            // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-            // grad[S4]   = grad[S5] @ vcur
-            // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
-            // grad[qcur] = grad[S1] @ kcur
-            // grad[vcur] = grad[S5].T @ S4
-            // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
-            //
-            // in post-order:
-            //
-            // S1         = qcur.T @ kcur
-            // S2         = S1 * scale
-            // S3         = diag_mask_inf(S2, P)
-            // S4         = softmax(S3)
-            // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
-            // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
-            // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
-            // grad[qcur] = grad[S1] @ kcur
-            // grad[kcur] = grad[S1].T @ qcur
-            // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
-            //
-            // using less variables (SM=S4):
-            //
-            // S             = diag_mask_inf(qcur.T @ kcur * scale, P)
-            // SM            = softmax(S)
-            // S             = d[:D,iq1,iq2,iq3] @ vcur
-            // dot_SM_gradSM = dot(SM, S)
-            // S             = SM * (S - dot(SM, S))
-            // S             = diag_mask_zero(S, P) * scale
-            //
-            // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
-            // grad[k][:D,:M,iq2,iq3]  += S.T @ qcur
-            // grad[v][:M,:D,iq2,iq3]  += d[:D,iq1,iq2,iq3].T @ SM
-        }
-
-        // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
-        // S = d[:D,iq1,iq2,iq3] @ vcur
-        // S[:M] += vcur[:,ic] * d[ic,iq1,iq2,iq3]
-        ggml_vec_set_f32(D, S, 0);
-        for (int64_t ic = 0; ic < D; ++ic) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
-
-            ggml_vec_mad_f32(M,
-                    S,
-                     (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
-                    *(float *) ((char *) d->data + (ic*nbd1 + i1*nbd2 + i2*nbd2 + i3*nbd3)));
-        }
-
-        // S = SM * (S - dot(SM, S))
-        float dot_SM_gradSM = 0;
-        ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
-        ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
-        ggml_vec_mul_f32 (M, S, S, SM);
-
-        // S = diag_mask_zero(S, P) * scale
-        if (masked) {
-            for (int64_t i = P + iq1 + 1; i < M; i++) {
-                S[i] = 0;
-            }
-        }
-        ggml_vec_scale_f32(M, S, scale);
-
-        void * grad_q = (char *) dst->data;
-        void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
-        void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
-
-        const size_t nbgq1 = nb0*neq0;
-        const size_t nbgq2 = nb0*neq0*neq1;
-        const size_t nbgq3 = nb0*neq0*neq1*neq2;
-
-        const size_t nbgk1 = nb0*nek0;
-        const size_t nbgk2 = nb0*nek0*nek1;
-        const size_t nbgk3 = nb0*nek0*nek1*neq2;
-
-        const size_t nbgv1 = nb0*nev0;
-        const size_t nbgv2 = nb0*nev0*nev1;
-        const size_t nbgv3 = nb0*nev0*nev1*neq2;
-
-        // S    shape [M,1]
-        // SM   shape [M,1]
-        // kcur shape [D,M]
-        // qcur shape [D,1]
-        // vcur shape [M,D]
-        //
-        // grad[q][:D,iq1,iq2,iq3] += S @ kcur
-        // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
-        // grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
-        // grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
-        for (int64_t ic = 0; ic < M; ++ic) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
-
-            ggml_vec_dot_f32(D,
-                    (float *) ((char *) grad_q  + (ic*nb0 + i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
-                    (float *) ((char *) k->data + (         ic*nbk1  + i2*nbk2  + i3*nbk3)),
-                    S);
-        }
-
-        // grad[k][:D,:M,iq2,iq3] += S.T       @ qcur
-        // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
-        // grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
-        for (int64_t ic = 0; ic < M; ++ic) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
-
-            ggml_vec_set_f32(D,
-                    (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
-                    0);
-            ggml_vec_mad_f32(D,
-                    (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
-                    (float *) ((char *) q->data + (i1*nbk1  + i2*nbk2  + i3*nbk3)),
-                    S[ic]);
-        }
-
-        // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T       @ SM
-        // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
-        // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3]         * SM[:M]
-        for (int64_t ic = 0; ic < D; ++ic) {
-            // dst indices
-            const int i1 = iq1;
-            const int i2 = iq2;
-            const int i3 = iq3;
-
-            ggml_vec_set_f32(M,
-                    (float *) ((char *) grad_v + (ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
-                    0);
-            ggml_vec_mad_f32(M,
-                    (float *) ((char *) grad_v + (ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
-                    SM,
-                    *(float *) ((char *) d->data + (ic*nbd1 + i1*nbd2 + i2*nbd2 + i3*nbd3)));
-        }
+        const int iq3 = ir/(neq2);
+        const int iq2 = ir - iq3*neq2;
+        for ( int iq1 = 0; iq1 < neq1; ++iq1) {
+
+            // not sure about CACHE_LINE_SIZE_F32..
+            // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+            float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+            float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+            for (int i = M; i < Mup; ++i) {
+                S[i] = -INFINITY;
+            }
+
+            for (int64_t ic = 0; ic < nek1; ++ic) {
+                // k indices
+                const int ik3 = iq3;
+                const int ik2 = iq2;
+                const int ik1 = ic;
+
+                // S indices
+                const int i1 = ik1;
+
+                ggml_vec_dot_f32(neq0,
+                        S + i1,
+                        (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                        (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+            }
+
+            // scale
+            ggml_vec_scale_f32(nek1, S, scale);
+
+            if (masked) {
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = -INFINITY;
+                    }
+                }
+            }
+
+            // softmax
+            {
+                float max = -INFINITY;
+                ggml_vec_max_f32(M, &max, S);
+
+                ggml_float sum = 0.0;
+                {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                    max = -max;
+                    vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                    vvexpf(SM, SM, &Mup);
+                    ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                    uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                    ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                    for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                        float * SR = S  + i;
+                        float * SW = SM + i;
+
+                        for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                            if (SR[j] == -INFINITY) {
+                                SW[j] = 0.0f;
+                            } else {
+                                ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
+                                memcpy(&scvt[j], &s, sizeof(uint16_t));
+                                const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                                sump[j] += (ggml_float)val;
+                                SW[j] = val;
+                            }
+                        }
+                    }
+
+                    for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                        sum += sump[i];
+                    }
+#endif
+                }
+
+                assert(sum > 0.0);
+
+                sum = 1.0/sum;
+                ggml_vec_scale_f32(M, SM, sum);
+
+            }
+
+            // step-by-step explanation
+            {
+                // forward-process                    shape      grads from backward process
+                // parallel_for iq2,iq3:
+                //  k[:D,:M,:,:]                     [D,M,:,:]  grad[k][:D,:M,iq2,iq3]  += grad[kcur]
+                //  q[:D,:N,:,:]                     [D,N,:,:]  grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+                //  v[:M,:D,:,:]                     [M,D,:,:]  grad[v][:M,:D,iq2,iq3]  += grad[vcur]
+                //  for iq1:
+                //   kcur  = k[:D,:M,iq2,iq3]        [D,M,1,1]  grad[kcur] = grad[S1].T @ qcur
+                //   qcur  = q[:D,iq1,iq2,iq3]       [D,1,1,1]  grad[qcur] = grad[S1]   @ kcur
+                //   vcur  = v[:M,:D,iq2,iq3]        [M,D,1,1]  grad[vcur] = grad[S5].T @ S4
+                //   S0    = -Inf                    [D,1,1,1]
+                //  ~S1[i] = dot(kcur[:D,i], qcur)
+                //   S1    = qcur.T @ kcur           [M,1,1,1]  grad[S1]   = grad[S2] * scale
+                //   S2    = S1 * scale              [M,1,1,1]  grad[S2]   = diag_mask_zero(grad[S3], P)
+                //   S3    = diag_mask_inf(S2, P)    [M,1,1,1]  grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                //   S4    = softmax(S3)             [M,1,1,1]  grad[S4]   = grad[S5] @ vcur
+                //  ~S5[i] = dot(vcur[:,i], S4)
+                //   S5    = S4.T @ vcur             [D,1,1,1]  grad[S5]   = d[:D,iq1,iq2,iq3]
+                //  ~dst[i,iq1,iq2,iq3]  = S5[i]                          ^
+                //   dst[:D,iq1,iq2,iq3] = S5        |  grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
+                //   dst                  backward-/    grad[dst]                 = d
+                //
+                // output gradients with their dependencies:
+                //
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S4]   = grad[S5] @ vcur
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[qcur] = grad[S1] @ kcur
+                // grad[vcur] = grad[S5].T @ S4
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // in post-order:
+                //
+                // S1         = qcur.T @ kcur
+                // S2         = S1 * scale
+                // S3         = diag_mask_inf(S2, P)
+                // S4         = softmax(S3)
+                // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+                // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+                // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+                // grad[qcur] = grad[S1] @ kcur
+                // grad[kcur] = grad[S1].T @ qcur
+                // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+                //
+                // using less variables (SM=S4):
+                //
+                // S             = diag_mask_inf(qcur.T @ kcur * scale, P)
+                // SM            = softmax(S)
+                // S             = d[:D,iq1,iq2,iq3] @ vcur
+                // dot_SM_gradSM = dot(SM, S)
+                // S             = SM * (S - dot(SM, S))
+                // S             = diag_mask_zero(S, P) * scale
+                //
+                // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+                // grad[k][:D,:M,iq2,iq3]  += S.T @ qcur
+                // grad[v][:M,:D,iq2,iq3]  += d[:D,iq1,iq2,iq3].T @ SM
+            }
+
+            // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
+            // S = d[:D,iq1,iq2,iq3] @ vcur
+            // S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
+            ggml_vec_set_f32(M, S, 0);
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(M,
+                        S,
+                         (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+
+            // S = SM * (S - dot(SM, S))
+            float dot_SM_gradSM = 0;
+            ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
+            ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+            ggml_vec_mul_f32 (M, S, S, SM);
+
+            // S = diag_mask_zero(S, P) * scale
+            if (masked) {
+                // for (int64_t i = P + iq1 + 1; i < M; i++) {
+                //     S[i] = 0;
+                // }
+                for (int64_t i = P; i < M; i++) {
+                    if (i > P + iq1) {
+                        S[i] = 0;
+                    }
+                }
+            }
+            ggml_vec_scale_f32(M, S, scale);
+
+            void * grad_q = (char *) dst->data;
+            void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
+            void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
+
+            const size_t nbgq1 = nb0*neq0;
+            const size_t nbgq2 = nb0*neq0*neq1;
+            const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+            const size_t nbgk1 = nb0*nek0;
+            const size_t nbgk2 = nb0*nek0*nek1;
+            const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+            const size_t nbgv1 = nb0*nev0;
+            const size_t nbgv2 = nb0*nev0*nev1;
+            const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+            // S    shape [M,1]
+            // SM   shape [M,1]
+            // kcur shape [D,M]
+            // qcur shape [D,1]
+            // vcur shape [M,D]
+            //
+            // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+            // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+            // grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
+            //
+            //// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
+            //// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_q  + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
+                        (float *) ((char *) k->data + (ic*nbk1  + i2*nbk2  + i3*nbk3)),
+                        S[ic]);
+            }
+
+            // grad[k][:D,:M,iq2,iq3] += S.T       @ qcur
+            // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+            // grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+            for (int64_t ic = 0; ic < M; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(D,
+                //         (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                //         0);
+                ggml_vec_mad_f32(D,
+                        (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                        (float *) ((char *) q->data + (i1*nbq1  + i2*nbq2  + i3*nbq3)),
+                        S[ic]);
+            }
+
+            // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T       @ SM
+            // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
+            // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3]         * SM[:M]
+            for (int64_t ic = 0; ic < D; ++ic) {
+                // dst indices
+                const int i1 = iq1;
+                const int i2 = iq2;
+                const int i3 = iq3;
+
+                // ggml_vec_set_f32(M,
+                //         (float *) ((char *) grad_v + (ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                //         0);
+                ggml_vec_mad_f32(M,
+                        (float *) ((char *) grad_v + (ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                        SM,
+                        *(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
+            }
+        }
     }
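
For reference, the identity behind the vec_dot/vec_acc1/vec_mul triple above ("S = SM * (S - dot(SM, S))") is the standard softmax backward. With SM = softmax(S3) and g = grad[S4], the softmax Jacobian gives

    % softmax backward, as implemented by the three vector ops above:
    \frac{\partial L}{\partial S3_i}
        = \sum_j g_j \, SM_j \, (\delta_{ij} - SM_i)
        = SM_i \bigl( g_i - \langle SM, g \rangle \bigr)

so one dot product (dot_SM_gradSM in the code), a scalar subtraction, and an elementwise multiply by SM recover the full Jacobian-vector product without ever materializing the M-by-M Jacobian.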
@@ -14475,9 +14494,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     bool masked = t != 0;
                     flash_grad =
                         ggml_flash_attn_back(ctx,
-                            src0->grad,
-                            src1->grad,
-                            tensor->opt[0]->grad,
+                            src0,
+                            src1,
+                            tensor->opt[0],
                             tensor->grad,
                             masked);
                 }
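
Passing src0, src1 and tensor->opt[0] (the forward q, k and v) instead of their grads matches the reworked kernel: it recomputes softmax(scale * q.k^T) from the forward inputs and produces the three input gradients itself, with d = tensor->grad only supplying the gradient flowing in from above. The declaration this implies, spelled out as an aid; the signature is inferred from these call sites, not shown in the diff:

    struct ggml_tensor * ggml_flash_attn_back(
            struct ggml_context * ctx,
            struct ggml_tensor  * q,  // forward input, reused to recompute q.k^T
            struct ggml_tensor  * k,  // forward input
            struct ggml_tensor  * v,  // forward input
            struct ggml_tensor  * d,  // gradient flowing into the forward output
            bool                  masked);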
@@ -14509,7 +14528,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         } break;
                     case 4:
                         {
-                            grad_q = ggml_view_3d(ctx,
+                            grad_q = ggml_view_4d(ctx,
                                 flash_grad,
                                 src0->ne[0],
                                 src0->ne[1],
@@ -14555,7 +14574,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         } break;
                     case 4:
                         {
-                            grad_k = ggml_view_3d(ctx,
+                            grad_k = ggml_view_4d(ctx,
                                 flash_grad,
                                 src1->ne[0],
                                 src1->ne[1],
@@ -14604,7 +14623,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         } break;
                     case 4:
                         {
-                            grad_v = ggml_view_3d(ctx,
+                            grad_v = ggml_view_4d(ctx,
                                 flash_grad,
                                 opt0->ne[0],
                                 opt0->ne[1],
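
The view_3d to view_4d switch in these three hunks matters because src0, src1 and opt0 are 4-dimensional: a 3-d view would silently drop ne[3]. Each gradient is a view into the packed flash_grad buffer at the same offsets the kernel used when writing. A sketch of the q case, assuming the usual ggml_view_4d argument order (ne0..ne3, nb1..nb3, byte offset); the helper name is illustrative:

    // carve grad[q] out of the packed flash_grad buffer; grad[k] and grad[v]
    // follow at nb0*D*N*ne2*ne3 and nb0*(D*N + D*M)*ne2*ne3 respectively
    static struct ggml_tensor * view_grad_q(
            struct ggml_context * ctx,
            struct ggml_tensor  * flash_grad,
            struct ggml_tensor  * q) {
        const size_t nb0 = ggml_element_size(flash_grad);
        return ggml_view_4d(ctx, flash_grad,
                q->ne[0], q->ne[1], q->ne[2], q->ne[3],
                nb0*q->ne[0],                       // nb1
                nb0*q->ne[0]*q->ne[1],              // nb2
                nb0*q->ne[0]*q->ne[1]*q->ne[2],     // nb3
                0);                                 // grad[q] starts at offset 0
    }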