support grouped-query-attention in ggml_flash_attn and ggml_flash_attn_back

k and v can now be repeated in q along ne[2]

in the forward pass we can just use modulo to compute the k and v indices, e.g. ik2 = iq2 % nek2.
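as a minimal sketch, the forward-pass index mapping looks like this (variable names as in the ggml_compute_forward_flash_attn_* hunks below; it assumes neq2 is a multiple of nek2):

    // each q head iq2 reads from the k/v head selected by modulo
    const int ik2 = iq2 % nek2;   // k head index
    const int iv2 = iq2 % nev2;   // v head index
    const int ik3 = iq3;          // ne[3] (batch) is not repeated
    const int iv3 = iq3;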

in the backward pass this is not as easy, because multiple threads would compete to accumulate into the same k->grad[:,ik1,ik2,ik3] and v->grad[:,iv1,iv2,iv3].
so we change the parallelization from q rows to k rows. this ensures non-overlapping (ik2,ik3) across threads.
in each thread we then iterate over the number of repetitions of k/v in q and compute iq2 as iq2 = ik2 + irep*nek2.
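the resulting loop structure in the backward pass is roughly this (paraphrased from the ggml_compute_forward_flash_attn_back_f32 hunk below):

    // each thread owns a disjoint range of k rows, i.e. disjoint (ik2,ik3)
    const int nrep = neq2/nek2;               // how often k2 (and v2) is repeated in q2
    for (int ir = ir0; ir < ir1; ++ir) {
        const int ik3 = ir/nek2;
        const int ik2 = ir - ik3*nek2;
        for (int irep = 0; irep < nrep; ++irep) {
            const int iq2 = ik2 + irep*nek2;  // q head that shares this k/v head
            for (int iq1 = 0; iq1 < neq1; ++iq1) {
                // accumulate into k->grad[:,:,ik2,ik3] and v->grad[:,:,ik2,ik3]
                // without racing with other threads
            }
        }
    }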

since ne2 is not the same for q, k and v, we also change how the gradients are concatenated into the result tensor.
additionally, the offsets of gradq, gradk and gradv in the result tensor are now memory-aligned.
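the offsets are computed as in the diff below, with GGML_PAD rounding each gradient block up to GGML_MEM_ALIGN:

    const size_t tsize  = ggml_type_size(GGML_TYPE_F32);
    const size_t offs_q = 0;
    const size_t offs_k = offs_q + GGML_PAD(elem_q*tsize, GGML_MEM_ALIGN);
    const size_t offs_v = offs_k + GGML_PAD(elem_k*tsize, GGML_MEM_ALIGN);
    const size_t end    = offs_v + GGML_PAD(elem_v*tsize, GGML_MEM_ALIGN);
    // the result is a 1d f32 tensor with (end + tsize - 1)/tsize elements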

we also simplify the compute_backward part of flash_attn to use ggml_reshape instead of switching over the number of dimensions.
this needs a small change to ggml_reshape: the assertion that its second argument is contiguous is removed.
since only the shape (ne) of the second reshape argument is relevant, its memory layout (nb) does not matter, so it may well be non-contiguous.
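in ggml_compute_backward the per-tensor gradients are then extracted from flash_grad like this (see the hunk below):

    struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
    struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0); // src0 only supplies the target shape (ne)
    src0->grad = ggml_add_or_set(ctx, src0->grad, grad_q, zero_table);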

change test-grad0 to also test for repeated k/v in q.
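for ndims >= 3 the test now builds shapes with repeated k/v heads, e.g. (from the test-grad0 hunk below):

    int64_t neq[4] = { D, N, B*nrep, ne[3] };  // q has nrep times as many heads
    int64_t nek[4] = { D, M, B,      ne[3] };
    int64_t nev[4] = { M, D, B,      ne[3] };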

this changes the rng and now results in small gradient differences in softmax. these come solely from using the f16 exp table lookup in the forward softmax: when softmax is temporarily changed to use the actual exp function, the reported gradient differences go away. gradient differences that come solely from the f16 table lookup are acceptable.
added a note to explain this.
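for reference, the forward softmax exponentiates roughly like this (simplified sketch; the exact table-lookup code in ggml.c may differ slightly), while the backward pass assumes an exact expf:

    // forward: exp via a precomputed table indexed by the f16 bit pattern
    ggml_fp16_t fp16 = GGML_FP32_TO_FP16(S[i] - max);
    uint16_t bits;
    memcpy(&bits, &fp16, sizeof(bits));
    S[i] = GGML_FP16_TO_FP32(table_exp_f16[bits]);
    // backward: treats the value as if it were expf(S[i] - max), hence the tiny differences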
xaedes 2023-09-09 17:01:54 +02:00
parent 0c2c9c7545
commit d7aade7d8a
2 changed files with 353 additions and 412 deletions

ggml.c

@@ -6585,7 +6585,7 @@ struct ggml_tensor * ggml_reshape(
struct ggml_tensor * a,
struct ggml_tensor * b) {
GGML_ASSERT(ggml_is_contiguous(a));
GGML_ASSERT(ggml_is_contiguous(b));
// as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
bool is_node = false;
@@ -7655,27 +7655,30 @@ struct ggml_tensor * ggml_flash_attn_back(
// d shape [D,N,ne2,ne3]
// q shape [D,N,ne2,ne3]
// k shape [D,M,ne2,ne3]
// v shape [M,D,ne2,ne3]
// k shape [D,M,kvne2,ne3]
// v shape [M,D,kvne2,ne3]
const int64_t D = q->ne[0];
const int64_t N = q->ne[1];
const int64_t M = k->ne[1];
const int64_t ne2 = q->ne[2];
const int64_t ne3 = q->ne[3];
const int64_t kvne2 = k->ne[2];
GGML_ASSERT(k->ne[0] == D);
GGML_ASSERT(v->ne[0] == M);
GGML_ASSERT(v->ne[1] == D);
GGML_ASSERT(d->ne[0] == D);
GGML_ASSERT(d->ne[1] == N);
GGML_ASSERT(k->ne[2] == ne2);
GGML_ASSERT(k->ne[2] == kvne2);
GGML_ASSERT(k->ne[3] == ne3);
GGML_ASSERT(v->ne[2] == ne2);
GGML_ASSERT(v->ne[2] == kvne2);
GGML_ASSERT(v->ne[3] == ne3);
GGML_ASSERT(d->ne[2] == ne2);
GGML_ASSERT(d->ne[3] == ne3);
GGML_ASSERT(ne2 % kvne2 == 0);
bool is_node = false;
if (q->grad || k->grad || v->grad) {
@@ -7685,14 +7688,23 @@ struct ggml_tensor * ggml_flash_attn_back(
}
// store gradients of q, k and v as continuous tensors concatenated in result.
// q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
// gradq->data = result->data
// gradk->data = result->data + nb0*D*N*ne2*ne3
// gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
// note: v and gradv are actually transposed, i.e. v->ne[0] != D.
int64_t ne[4] = {D,M+N+M,ne2,ne3};
const int64_t elem_q = ggml_nelements(q);
const int64_t elem_k = ggml_nelements(k);
const int64_t elem_v = ggml_nelements(v);
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
enum ggml_type result_type = GGML_TYPE_F32;
GGML_ASSERT(ggml_blck_size(result_type) == 1);
const size_t tsize = ggml_type_size(result_type);
const size_t offs_q = 0;
const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
const size_t nelements = (end + tsize - 1)/tsize;
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
int32_t masked_i = masked ? 1 : 0;
ggml_set_op_params(result, &masked_i, sizeof(masked_i));
@@ -14425,7 +14437,7 @@ static void ggml_compute_forward_flash_attn_f32(
for (int64_t ic = 0; ic < nek1; ++ic) {
// k indices
const int ik3 = iq3;
const int ik2 = iq2;
const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
@@ -14449,6 +14461,8 @@ static void ggml_compute_forward_flash_attn_f32(
}
// softmax
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
// dont forget to set their SW values to zero
{
float max = -INFINITY;
ggml_vec_max_f32(M, &max, S);
@@ -14509,9 +14523,13 @@ static void ggml_compute_forward_flash_attn_f32(
const int i2 = iq2;
const int i3 = iq3;
ggml_vec_dot_f32(nek1,
// v indices
const int iv2 = iq2 % nev2;
const int iv3 = iq3;
ggml_vec_dot_f32(nev0,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
(float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S);
}
}
@@ -14608,7 +14626,7 @@ static void ggml_compute_forward_flash_attn_f16(
for (int64_t ic = 0; ic < nek1; ++ic) {
// k indices
const int ik3 = iq3;
const int ik2 = iq2;
const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
@@ -14623,7 +14641,7 @@ static void ggml_compute_forward_flash_attn_f16(
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
// k indices
const int ik3 = iq3;
const int ik2 = iq2;
const int ik2 = iq2 % nek2;
const int ik1 = ic;
// S indices
@@ -14648,6 +14666,8 @@ static void ggml_compute_forward_flash_attn_f16(
}
// softmax
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
// dont forget to set their S values to zero
{
float max = -INFINITY;
ggml_vec_max_f32(M, &max, S);
@@ -14704,6 +14724,7 @@ static void ggml_compute_forward_flash_attn_f16(
S16[i] = GGML_FP32_TO_FP16(S[i]);
}
// todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
for (int64_t ic = 0; ic < nev1; ++ic) {
// dst indices
@@ -14711,9 +14732,13 @@ static void ggml_compute_forward_flash_attn_f16(
const int i2 = iq2;
const int i3 = iq3;
ggml_vec_dot_f16(nek1,
// v indices
const int iv2 = iq2 % nev2;
const int iv3 = iq3;
ggml_vec_dot_f16(nev0,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S16);
}
} else {
@@ -14723,9 +14748,13 @@ static void ggml_compute_forward_flash_attn_f16(
const int i2 = iq2;
const int i3 = iq3;
ggml_vec_dot_f16_unroll(nek1, nbv1,
// v indices
const int iv2 = iq2 % nev2;
const int iv3 = iq3;
ggml_vec_dot_f16_unroll(nev0, nbv1,
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
S16);
}
}
@@ -14984,10 +15013,38 @@ static void ggml_compute_forward_flash_attn_back_f32(
return;
}
// parallelize by q rows using ggml_vec_dot_f32
const int64_t elem_q = ggml_nelements(q);
const int64_t elem_k = ggml_nelements(k);
const int64_t elem_v = ggml_nelements(v);
// total rows in q
const int nr = neq2*neq3;
enum ggml_type result_type = dst->type;
GGML_ASSERT(ggml_blck_size(result_type) == 1);
const size_t tsize = ggml_type_size(result_type);
const size_t offs_q = 0;
const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
void * grad_q = (char *) dst->data;
void * grad_k = (char *) dst->data + offs_k;
void * grad_v = (char *) dst->data + offs_v;
const size_t nbgq1 = nb0*neq0;
const size_t nbgq2 = nb0*neq0*neq1;
const size_t nbgq3 = nb0*neq0*neq1*neq2;
const size_t nbgk1 = nb0*nek0;
const size_t nbgk2 = nb0*nek0*nek1;
const size_t nbgk3 = nb0*nek0*nek1*neq2;
const size_t nbgv1 = nb0*nev0;
const size_t nbgv2 = nb0*nev0*nev1;
const size_t nbgv3 = nb0*nev0*nev1*neq2;
// parallelize by k rows using ggml_vec_dot_f32
// total rows in k
const int nr = nek2*nek3;
// rows per thread
const int dr = (nr + nth - 1)/nth;
@@ -15000,12 +15057,26 @@ static void ggml_compute_forward_flash_attn_back_f32(
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
// how often k2 (and v2) is repeated in q2
int nrep = neq2/nek2;
for (int ir = ir0; ir < ir1; ++ir) {
// q indices
const int iq3 = ir/(neq2);
const int iq2 = ir - iq3*neq2;
for ( int iq1 = 0; iq1 < neq1; ++iq1) {
const int ik3 = ir/(nek2);
const int ik2 = ir - ik3*nek2;
const int iq3 = ik3;
const int id3 = ik3;
const int iv3 = ik3;
const int iv2 = ik2;
for (int irep = 0; irep < nrep; ++irep) {
const int iq2 = ik2 + irep*nek2;
const int id2 = iq2;
// (ik2 + irep*nek2) % nek2 == ik2
for (int iq1 = 0; iq1 < neq1; ++iq1) {
const int id1 = iq1;
// not sure about CACHE_LINE_SIZE_F32..
// - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
@@ -15018,8 +15089,6 @@ static void ggml_compute_forward_flash_attn_back_f32(
for (int64_t ic = 0; ic < nek1; ++ic) {
// k indices
const int ik3 = iq3;
const int ik2 = iq2;
const int ik1 = ic;
// S indices
@@ -15043,6 +15112,8 @@ static void ggml_compute_forward_flash_attn_back_f32(
}
// softmax
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
// dont forget to set their SM values to zero
{
float max = -INFINITY;
ggml_vec_max_f32(M, &max, S);
@@ -15095,14 +15166,16 @@ static void ggml_compute_forward_flash_attn_back_f32(
// step-by-step explanation
{
// forward-process shape grads from backward process
// parallel_for iq2,iq3:
// k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
// parallel_for ik2,ik3:
// for irep:
// iq2 = ik2 + irep*nek2
// k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
// q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
// v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
// v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
// for iq1:
// kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
// kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
// qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
// vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
// vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
// S0 = -Inf [D,1,1,1]
// ~S1[i] = dot(kcur[:D,i], qcur)
// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
@@ -15110,9 +15183,9 @@ static void ggml_compute_forward_flash_attn_back_f32(
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
// ~S5[i] = dot(vcur[:,i], S4)
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
// dst backward-/ grad[dst] = d
//
// output gradients with their dependencies:
@@ -15121,10 +15194,10 @@ static void ggml_compute_forward_flash_attn_back_f32(
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
// grad[S4] = grad[S5] @ vcur
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
// grad[S4] = d[:D,id1,id2,id3] @ vcur
// grad[qcur] = grad[S1] @ kcur
// grad[vcur] = grad[S5].T @ S4
// grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
// grad[vcur] = d[:D,id1,id2,id3].T @ S4
//
// in post-order:
//
@@ -15132,12 +15205,12 @@ static void ggml_compute_forward_flash_attn_back_f32(
// S2 = S1 * scale
// S3 = diag_mask_inf(S2, P)
// S4 = softmax(S3)
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
// grad[S4] = d[:D,id1,id2,id3] @ vcur
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
// grad[qcur] = grad[S1] @ kcur
// grad[kcur] = grad[S1].T @ qcur
// grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
// grad[vcur] = d[:D,id1,id2,id3].T @ S4
//
// using less variables (SM=S4):
//
@@ -15149,24 +15222,20 @@ static void ggml_compute_forward_flash_attn_back_f32(
// S = diag_mask_zero(S, P) * scale
//
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
// grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
// grad[k][:D,:M,ik2,ik3] += S.T @ qcur
// grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
}
// S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
// S = d[:D,iq1,iq2,iq3] @ vcur
// S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
// S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
// S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
// for ic:
// S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
ggml_vec_set_f32(M, S, 0);
for (int64_t ic = 0; ic < D; ++ic) {
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
ggml_vec_mad_f32(M,
S,
(float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
*(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
}
// S = SM * (S - dot(SM, S))
@@ -15186,82 +15255,50 @@ static void ggml_compute_forward_flash_attn_back_f32(
}
}
}
// todo: exclude known zero S[..] values from operation
ggml_vec_scale_f32(M, S, scale);
void * grad_q = (char *) dst->data;
void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
const size_t nbgq1 = nb0*neq0;
const size_t nbgq2 = nb0*neq0*neq1;
const size_t nbgq3 = nb0*neq0*neq1*neq2;
const size_t nbgk1 = nb0*nek0;
const size_t nbgk2 = nb0*nek0*nek1;
const size_t nbgk3 = nb0*nek0*nek1*neq2;
const size_t nbgv1 = nb0*nev0;
const size_t nbgv2 = nb0*nev0*nev1;
const size_t nbgv3 = nb0*nev0*nev1*neq2;
// S shape [M,1]
// SM shape [M,1]
// kcur shape [D,M]
// qcur shape [D,1]
// vcur shape [M,D]
//
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
//
//// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
//// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
// for ic:
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
// todo: exclude known zero S[..] values from loop
for (int64_t ic = 0; ic < M; ++ic) {
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
ggml_vec_mad_f32(D,
(float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
(float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
(float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
(float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
S[ic]);
}
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
// for ic:
// grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
// grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
// todo: exclude known zero S[..] values from loop
for (int64_t ic = 0; ic < M; ++ic) {
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// ggml_vec_set_f32(D,
// (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
// 0);
ggml_vec_mad_f32(D,
(float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
(float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
(float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
S[ic]);
}
// grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
// grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
// grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
// grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
// for ic:
// grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
// grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
// todo: exclude known zero SM[..] values from mad
for (int64_t ic = 0; ic < D; ++ic) {
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// ggml_vec_set_f32(M,
// (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
// 0);
ggml_vec_mad_f32(M,
(float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
(float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
SM,
*(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
}
}
}
}
@@ -17166,143 +17203,40 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
masked);
}
if (src0->grad) {
struct ggml_tensor * grad_q = NULL;
const size_t nb0 = flash_grad->nb[0];
const size_t offset = 0;
switch(src0->n_dims) {
case 2:
{
grad_q = ggml_view_2d(ctx,
flash_grad,
src0->ne[0],
src0->ne[1],
nb0*src0->ne[0],
offset);
} break;
case 3:
{
grad_q = ggml_view_3d(ctx,
flash_grad,
src0->ne[0],
src0->ne[1],
src0->ne[2],
nb0*src0->ne[0],
nb0*src0->ne[0]*src0->ne[1],
offset);
} break;
case 4:
{
grad_q = ggml_view_4d(ctx,
flash_grad,
src0->ne[0],
src0->ne[1],
src0->ne[2],
src0->ne[3],
nb0*src0->ne[0],
nb0*src0->ne[0]*src0->ne[1],
nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
offset);
} break;
}
struct ggml_tensor * src2 = tensor->src[2];
const int64_t elem_q = ggml_nelements(src0);
const int64_t elem_k = ggml_nelements(src1);
const int64_t elem_v = ggml_nelements(src2);
enum ggml_type result_type = flash_grad->type;
GGML_ASSERT(ggml_blck_size(result_type) == 1);
const size_t tsize = ggml_type_size(result_type);
const size_t offs_q = 0;
const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
if (src0->grad) {
struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
src0->grad = ggml_add_or_set(ctx,
src0->grad,
grad_q,
zero_table);
}
if (src1->grad) {
struct ggml_tensor * grad_k = NULL;
const size_t nb0 = flash_grad->nb[0];
const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
switch(src1->n_dims) {
case 2:
{
grad_k = ggml_view_2d(ctx,
flash_grad,
src1->ne[0],
src1->ne[1],
nb0*src1->ne[0],
offset);
} break;
case 3:
{
grad_k = ggml_view_3d(ctx,
flash_grad,
src1->ne[0],
src1->ne[1],
src1->ne[2],
nb0*src1->ne[0],
nb0*src1->ne[0]*src1->ne[1],
offset);
} break;
case 4:
{
grad_k = ggml_view_4d(ctx,
flash_grad,
src1->ne[0],
src1->ne[1],
src1->ne[2],
src1->ne[3],
nb0*src1->ne[0],
nb0*src1->ne[0]*src1->ne[1],
nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
offset);
} break;
}
struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
src1->grad = ggml_add_or_set(ctx,
src1->grad,
grad_k,
zero_table);
}
struct ggml_tensor * opt0 = tensor->src[2];
if (opt0->grad) {
struct ggml_tensor * grad_v = NULL;
const size_t nb0 = flash_grad->nb[0];
const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
switch(opt0->n_dims) {
case 2:
{
grad_v = ggml_view_2d(ctx,
flash_grad,
opt0->ne[0],
opt0->ne[1],
nb0*opt0->ne[0],
offset);
} break;
case 3:
{
grad_v = ggml_view_3d(ctx,
flash_grad,
opt0->ne[0],
opt0->ne[1],
opt0->ne[2],
nb0*opt0->ne[0],
nb0*opt0->ne[0]*opt0->ne[1],
offset);
} break;
case 4:
{
grad_v = ggml_view_4d(ctx,
flash_grad,
opt0->ne[0],
opt0->ne[1],
opt0->ne[2],
opt0->ne[3],
nb0*opt0->ne[0],
nb0*opt0->ne[0]*opt0->ne[1],
nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
offset);
} break;
}
opt0->grad = ggml_add_or_set(ctx,
opt0->grad,
if (src2->grad) {
struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
src2->grad = ggml_add_or_set(ctx,
src2->grad,
grad_v,
zero_table);
}

test-grad0.c

@@ -1359,6 +1359,10 @@ int main(int argc, const char ** argv) {
ggml_new_f32(ctx0, eps))));
check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
// NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf.
// this may result in gradients that differ slightly from finite differences.
// when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause.
// if only the table lookup causes gradients to differ this is acceptable.
}
}
@@ -1474,7 +1478,9 @@ int main(int argc, const char ** argv) {
for (int masked = 0; masked <= 1; ++masked) {
for (int ndims = 2; ndims <= 4; ++ndims) {
int64_t neq[4] = { D, N, B, ne[3] };
int max_nrep = (ndims >= 3) ? 2 : 1;
for (int nrep = 1; nrep < max_nrep; ++nrep) {
int64_t neq[4] = { D, N, B*nrep, ne[3] };
int64_t nek[4] = { D, M, B, ne[3] };
int64_t nev[4] = { M, D, B, ne[3] };
if (ndims == 2) {
@@ -1499,6 +1505,7 @@ int main(int argc, const char ** argv) {
}
}
}
}
// flash_attn f16, not yet fully implemented
if(0)