bugfixes for backward pass of flash attention
This commit is contained in:
parent
22a7279ffb
commit
38560b6d51
1 changed files with 253 additions and 234 deletions
93
ggml.c
93
ggml.c
|
@ -6221,7 +6221,6 @@ struct ggml_tensor * ggml_flash_attn(
|
|||
bool is_node = false;
|
||||
|
||||
if (q->grad || k->grad || v->grad) {
|
||||
GGML_ASSERT(false); // TODO: implement backward
|
||||
is_node = true;
|
||||
}
|
||||
|
||||
|
@ -12882,10 +12881,15 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
//const int64_t nev2 = v->ne[2];
|
||||
//const int64_t nev3 = v->ne[3];
|
||||
|
||||
const int64_t ned0 = d->ne[0];
|
||||
const int64_t ned1 = d->ne[1];
|
||||
//const int64_t ned2 = d->ne[2];
|
||||
//const int64_t ned3 = d->ne[3];
|
||||
|
||||
const int64_t ne0 = dst->ne[0];
|
||||
const int64_t ne1 = dst->ne[1];
|
||||
//const int64_t ne2 = dst->ne[2];
|
||||
//const int64_t ne3 = dst->ne[3];
|
||||
const int64_t ne2 = dst->ne[2];
|
||||
const int64_t ne3 = dst->ne[3];
|
||||
|
||||
const int nbk0 = k->nb[0];
|
||||
const int nbk1 = k->nb[1];
|
||||
|
@ -12923,8 +12927,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
const int Mup = ggml_up(M, GGML_SOFT_MAX_UNROLL);
|
||||
const int mxDM = MAX(D, Mup);
|
||||
|
||||
GGML_ASSERT(ne0 == D);
|
||||
GGML_ASSERT(ne1 == N);
|
||||
// GGML_ASSERT(ne0 == D);
|
||||
// GGML_ASSERT(ne1 == N);
|
||||
GGML_ASSERT(P >= 0);
|
||||
|
||||
GGML_ASSERT(nbq0 == sizeof(float));
|
||||
|
@ -12934,10 +12938,12 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
GGML_ASSERT(neq0 == D);
|
||||
GGML_ASSERT(nek0 == D);
|
||||
GGML_ASSERT(nev1 == D);
|
||||
GGML_ASSERT(ned0 == D);
|
||||
|
||||
GGML_ASSERT(neq1 == N);
|
||||
GGML_ASSERT(nek1 == N + P);
|
||||
GGML_ASSERT(nev1 == D);
|
||||
GGML_ASSERT(ned1 == N);
|
||||
|
||||
// dst cannot be transposed or permuted
|
||||
GGML_ASSERT(nb0 == sizeof(float));
|
||||
|
@ -12946,6 +12952,9 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
GGML_ASSERT(nb2 <= nb3);
|
||||
|
||||
if (params->type == GGML_TASK_INIT) {
|
||||
if (ith == 0) {
|
||||
memset(dst->data, 0, nb0*ne0*ne1*ne2*ne3);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -12956,7 +12965,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
// parallelize by q rows using ggml_vec_dot_f32
|
||||
|
||||
// total rows in q
|
||||
const int nr = neq1*neq2*neq3;
|
||||
const int nr = neq2*neq3;
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
@ -12971,9 +12980,10 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
// q indices
|
||||
const int iq3 = ir/(neq2*neq1);
|
||||
const int iq2 = (ir - iq3*neq2*neq1)/neq1;
|
||||
const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
|
||||
const int iq3 = ir/(neq2);
|
||||
const int iq2 = (ir - iq3*neq2)/neq2;
|
||||
for ( int iq1 = 0; iq1 < neq1; ++iq1) {
|
||||
|
||||
|
||||
// not sure about CACHE_LINE_SIZE_F32..
|
||||
// - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
|
||||
|
@ -13027,17 +13037,18 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
||||
|
||||
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
||||
float * SS = SM + i;
|
||||
float * SR = S + i;
|
||||
float * SW = SM + i;
|
||||
|
||||
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
||||
if (SS[j] == -INFINITY) {
|
||||
SS[j] = 0.0f;
|
||||
if (SR[j] == -INFINITY) {
|
||||
SW[j] = 0.0f;
|
||||
} else {
|
||||
ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
|
||||
ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
|
||||
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
||||
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
|
||||
sump[j] += (ggml_float)val;
|
||||
SS[j] = val;
|
||||
SW[j] = val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -13118,8 +13129,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
|
||||
// S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
|
||||
// S = d[:D,iq1,iq2,iq3] @ vcur
|
||||
// S[:M] += vcur[:,ic] * d[ic,iq1,iq2,iq3]
|
||||
ggml_vec_set_f32(D, S, 0);
|
||||
// S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
|
||||
ggml_vec_set_f32(M, S, 0);
|
||||
for (int64_t ic = 0; ic < D; ++ic) {
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
|
@ -13129,7 +13140,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
ggml_vec_mad_f32(M,
|
||||
S,
|
||||
(float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
||||
*(float *) ((char *) d->data + (ic*nbd1 + i1*nbd2 + i2*nbd2 + i3*nbd3)));
|
||||
*(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
|
||||
}
|
||||
|
||||
// S = SM * (S - dot(SM, S))
|
||||
|
@ -13140,10 +13151,15 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
|
||||
// S = diag_mask_zero(S, P) * scale
|
||||
if (masked) {
|
||||
for (int64_t i = P + iq1 + 1; i < M; i++) {
|
||||
// for (int64_t i = P + iq1 + 1; i < M; i++) {
|
||||
// S[i] = 0;
|
||||
// }
|
||||
for (int64_t i = P; i < M; i++) {
|
||||
if (i > P + iq1) {
|
||||
S[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
ggml_vec_scale_f32(M, S, scale);
|
||||
|
||||
void * grad_q = (char *) dst->data;
|
||||
|
@ -13170,18 +13186,20 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
//
|
||||
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
||||
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
|
||||
// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
|
||||
// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
|
||||
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
|
||||
//
|
||||
//// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
|
||||
//// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
|
||||
for (int64_t ic = 0; ic < M; ++ic) {
|
||||
// dst indices
|
||||
const int i1 = iq1;
|
||||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
ggml_vec_dot_f32(D,
|
||||
(float *) ((char *) grad_q + (ic*nb0 + i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
|
||||
ggml_vec_mad_f32(D,
|
||||
(float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
|
||||
(float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
|
||||
S);
|
||||
S[ic]);
|
||||
}
|
||||
|
||||
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
|
||||
|
@ -13193,12 +13211,12 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
ggml_vec_set_f32(D,
|
||||
(float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
|
||||
0);
|
||||
// ggml_vec_set_f32(D,
|
||||
// (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
|
||||
// 0);
|
||||
ggml_vec_mad_f32(D,
|
||||
(float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
|
||||
(float *) ((char *) q->data + (i1*nbk1 + i2*nbk2 + i3*nbk3)),
|
||||
(float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
|
||||
S[ic]);
|
||||
}
|
||||
|
||||
|
@ -13211,13 +13229,14 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
const int i2 = iq2;
|
||||
const int i3 = iq3;
|
||||
|
||||
ggml_vec_set_f32(M,
|
||||
(float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
|
||||
0);
|
||||
// ggml_vec_set_f32(M,
|
||||
// (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
|
||||
// 0);
|
||||
ggml_vec_mad_f32(M,
|
||||
(float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
|
||||
SM,
|
||||
*(float *) ((char *) d->data + (ic*nbd1 + i1*nbd2 + i2*nbd2 + i3*nbd3)));
|
||||
*(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -14475,9 +14494,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
bool masked = t != 0;
|
||||
flash_grad =
|
||||
ggml_flash_attn_back(ctx,
|
||||
src0->grad,
|
||||
src1->grad,
|
||||
tensor->opt[0]->grad,
|
||||
src0,
|
||||
src1,
|
||||
tensor->opt[0],
|
||||
tensor->grad,
|
||||
masked);
|
||||
}
|
||||
|
@ -14509,7 +14528,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
} break;
|
||||
case 4:
|
||||
{
|
||||
grad_q = ggml_view_3d(ctx,
|
||||
grad_q = ggml_view_4d(ctx,
|
||||
flash_grad,
|
||||
src0->ne[0],
|
||||
src0->ne[1],
|
||||
|
@ -14555,7 +14574,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
} break;
|
||||
case 4:
|
||||
{
|
||||
grad_k = ggml_view_3d(ctx,
|
||||
grad_k = ggml_view_4d(ctx,
|
||||
flash_grad,
|
||||
src1->ne[0],
|
||||
src1->ne[1],
|
||||
|
@ -14604,7 +14623,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
} break;
|
||||
case 4:
|
||||
{
|
||||
grad_v = ggml_view_3d(ctx,
|
||||
grad_v = ggml_view_4d(ctx,
|
||||
flash_grad,
|
||||
opt0->ne[0],
|
||||
opt0->ne[1],
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue