support grouped-query-attention in ggml_flash_attn and ggml_flash_attn_back
k and v can now be repeated in q along ne[2] in forward pass just use modulo to compute k and v indices, like ik2 = iq2 % nek2. in backard pass this won't work as easy, because multiple threads will compete to accumulate to the same k->grad[:,ik1,ik2,ik3] and v->grad[:,iv1,iv2,iv3]. so we change the parallelization over q rows to be over k rows. this ensures non-overlapping (ik2,ik3) across threads. in each thread we then iterate over the number of repetitions of k/v in q to compute iq2 as iq2 = ik2 + irep*nek2. since ne2 is not the same for q,k and v we also change how the gradients are concatenated into the result tensor. additionally the offsets of gradq, gradk and gradv in the result tensor are now memory aligned. we also simplify the compute_backward part of flash_attn to use ggml_reshape instead of switching over the number of dimensions. this needs a small change to ggml_reshape, removing the assertion of second argument to be contiguous. since only the shape (ne) of the second reshape argument is of relevance, its memory layout (nb) is irrelevant -> it can very well be non-contiguous. change test-grad0 to also test for repeated k/v in q. this changes the rng and now results in small gradient differences in softmax. these solely come from using f16 exp table lookup in forward softmax: when temporarily changing softmax to use actual exp function, the reported gradient differences go away. gradient differences coming solely from f16 table lookup are acceptable. added a note to explain this.
This commit is contained in:
parent
0c2c9c7545
commit
d7aade7d8a
2 changed files with 353 additions and 412 deletions
716
ggml.c
716
ggml.c
|
@ -6585,7 +6585,7 @@ struct ggml_tensor * ggml_reshape(
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
struct ggml_tensor * b) {
|
struct ggml_tensor * b) {
|
||||||
GGML_ASSERT(ggml_is_contiguous(a));
|
GGML_ASSERT(ggml_is_contiguous(a));
|
||||||
GGML_ASSERT(ggml_is_contiguous(b));
|
// as only the shape of b is relevant, and not its memory layout, b is allowed to be non contiguous.
|
||||||
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
|
GGML_ASSERT(ggml_nelements(a) == ggml_nelements(b));
|
||||||
|
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
@ -7655,27 +7655,30 @@ struct ggml_tensor * ggml_flash_attn_back(
|
||||||
|
|
||||||
// d shape [D,N,ne2,ne3]
|
// d shape [D,N,ne2,ne3]
|
||||||
// q shape [D,N,ne2,ne3]
|
// q shape [D,N,ne2,ne3]
|
||||||
// k shape [D,M,ne2,ne3]
|
// k shape [D,M,kvne2,ne3]
|
||||||
// v shape [M,D,ne2,ne3]
|
// v shape [M,D,kvne2,ne3]
|
||||||
|
|
||||||
const int64_t D = q->ne[0];
|
const int64_t D = q->ne[0];
|
||||||
const int64_t N = q->ne[1];
|
const int64_t N = q->ne[1];
|
||||||
const int64_t M = k->ne[1];
|
const int64_t M = k->ne[1];
|
||||||
const int64_t ne2 = q->ne[2];
|
const int64_t ne2 = q->ne[2];
|
||||||
const int64_t ne3 = q->ne[3];
|
const int64_t ne3 = q->ne[3];
|
||||||
|
const int64_t kvne2 = k->ne[2];
|
||||||
|
|
||||||
GGML_ASSERT(k->ne[0] == D);
|
GGML_ASSERT(k->ne[0] == D);
|
||||||
GGML_ASSERT(v->ne[0] == M);
|
GGML_ASSERT(v->ne[0] == M);
|
||||||
GGML_ASSERT(v->ne[1] == D);
|
GGML_ASSERT(v->ne[1] == D);
|
||||||
GGML_ASSERT(d->ne[0] == D);
|
GGML_ASSERT(d->ne[0] == D);
|
||||||
GGML_ASSERT(d->ne[1] == N);
|
GGML_ASSERT(d->ne[1] == N);
|
||||||
GGML_ASSERT(k->ne[2] == ne2);
|
GGML_ASSERT(k->ne[2] == kvne2);
|
||||||
GGML_ASSERT(k->ne[3] == ne3);
|
GGML_ASSERT(k->ne[3] == ne3);
|
||||||
GGML_ASSERT(v->ne[2] == ne2);
|
GGML_ASSERT(v->ne[2] == kvne2);
|
||||||
GGML_ASSERT(v->ne[3] == ne3);
|
GGML_ASSERT(v->ne[3] == ne3);
|
||||||
GGML_ASSERT(d->ne[2] == ne2);
|
GGML_ASSERT(d->ne[2] == ne2);
|
||||||
GGML_ASSERT(d->ne[3] == ne3);
|
GGML_ASSERT(d->ne[3] == ne3);
|
||||||
|
|
||||||
|
GGML_ASSERT(ne2 % kvne2 == 0);
|
||||||
|
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
|
||||||
if (q->grad || k->grad || v->grad) {
|
if (q->grad || k->grad || v->grad) {
|
||||||
|
@ -7685,14 +7688,23 @@ struct ggml_tensor * ggml_flash_attn_back(
|
||||||
}
|
}
|
||||||
|
|
||||||
// store gradients of q, k and v as continuous tensors concatenated in result.
|
// store gradients of q, k and v as continuous tensors concatenated in result.
|
||||||
// q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
|
|
||||||
// gradq->data = result->data
|
|
||||||
// gradk->data = result->data + nb0*D*N*ne2*ne3
|
|
||||||
// gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
|
|
||||||
// note: v and gradv are actually transposed, i.e. v->ne[0] != D.
|
// note: v and gradv are actually transposed, i.e. v->ne[0] != D.
|
||||||
int64_t ne[4] = {D,M+N+M,ne2,ne3};
|
const int64_t elem_q = ggml_nelements(q);
|
||||||
|
const int64_t elem_k = ggml_nelements(k);
|
||||||
|
const int64_t elem_v = ggml_nelements(v);
|
||||||
|
|
||||||
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
|
enum ggml_type result_type = GGML_TYPE_F32;
|
||||||
|
GGML_ASSERT(ggml_blck_size(result_type) == 1);
|
||||||
|
const size_t tsize = ggml_type_size(result_type);
|
||||||
|
|
||||||
|
const size_t offs_q = 0;
|
||||||
|
const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
|
||||||
|
const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
|
||||||
|
const size_t end = offs_v + GGML_PAD(elem_v * tsize, GGML_MEM_ALIGN);
|
||||||
|
|
||||||
|
const size_t nelements = (end + tsize - 1)/tsize;
|
||||||
|
|
||||||
|
struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nelements);
|
||||||
|
|
||||||
int32_t masked_i = masked ? 1 : 0;
|
int32_t masked_i = masked ? 1 : 0;
|
||||||
ggml_set_op_params(result, &masked_i, sizeof(masked_i));
|
ggml_set_op_params(result, &masked_i, sizeof(masked_i));
|
||||||
|
@ -14425,7 +14437,7 @@ static void ggml_compute_forward_flash_attn_f32(
|
||||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||||
// k indices
|
// k indices
|
||||||
const int ik3 = iq3;
|
const int ik3 = iq3;
|
||||||
const int ik2 = iq2;
|
const int ik2 = iq2 % nek2;
|
||||||
const int ik1 = ic;
|
const int ik1 = ic;
|
||||||
|
|
||||||
// S indices
|
// S indices
|
||||||
|
@ -14449,6 +14461,8 @@ static void ggml_compute_forward_flash_attn_f32(
|
||||||
}
|
}
|
||||||
|
|
||||||
// softmax
|
// softmax
|
||||||
|
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
|
||||||
|
// dont forget to set their SW values to zero
|
||||||
{
|
{
|
||||||
float max = -INFINITY;
|
float max = -INFINITY;
|
||||||
ggml_vec_max_f32(M, &max, S);
|
ggml_vec_max_f32(M, &max, S);
|
||||||
|
@ -14509,9 +14523,13 @@ static void ggml_compute_forward_flash_attn_f32(
|
||||||
const int i2 = iq2;
|
const int i2 = iq2;
|
||||||
const int i3 = iq3;
|
const int i3 = iq3;
|
||||||
|
|
||||||
ggml_vec_dot_f32(nek1,
|
// v indices
|
||||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
const int iv2 = iq2 % nev2;
|
||||||
(float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
const int iv3 = iq3;
|
||||||
|
|
||||||
|
ggml_vec_dot_f32(nev0,
|
||||||
|
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||||
|
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||||
S);
|
S);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14608,7 +14626,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
||||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||||
// k indices
|
// k indices
|
||||||
const int ik3 = iq3;
|
const int ik3 = iq3;
|
||||||
const int ik2 = iq2;
|
const int ik2 = iq2 % nek2;
|
||||||
const int ik1 = ic;
|
const int ik1 = ic;
|
||||||
|
|
||||||
// S indices
|
// S indices
|
||||||
|
@ -14623,7 +14641,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
||||||
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
||||||
// k indices
|
// k indices
|
||||||
const int ik3 = iq3;
|
const int ik3 = iq3;
|
||||||
const int ik2 = iq2;
|
const int ik2 = iq2 % nek2;
|
||||||
const int ik1 = ic;
|
const int ik1 = ic;
|
||||||
|
|
||||||
// S indices
|
// S indices
|
||||||
|
@ -14648,6 +14666,8 @@ static void ggml_compute_forward_flash_attn_f16(
|
||||||
}
|
}
|
||||||
|
|
||||||
// softmax
|
// softmax
|
||||||
|
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
|
||||||
|
// dont forget to set their S values to zero
|
||||||
{
|
{
|
||||||
float max = -INFINITY;
|
float max = -INFINITY;
|
||||||
ggml_vec_max_f32(M, &max, S);
|
ggml_vec_max_f32(M, &max, S);
|
||||||
|
@ -14704,6 +14724,7 @@ static void ggml_compute_forward_flash_attn_f16(
|
||||||
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
S16[i] = GGML_FP32_TO_FP16(S[i]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo: exclude known zero S[..] values from dot (reducing nev0 and increasing begin of v and S16).
|
||||||
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
|
if (GGML_VEC_DOT_UNROLL == 1 || (nev1 % GGML_VEC_DOT_UNROLL != 0)) {
|
||||||
for (int64_t ic = 0; ic < nev1; ++ic) {
|
for (int64_t ic = 0; ic < nev1; ++ic) {
|
||||||
// dst indices
|
// dst indices
|
||||||
|
@ -14711,9 +14732,13 @@ static void ggml_compute_forward_flash_attn_f16(
|
||||||
const int i2 = iq2;
|
const int i2 = iq2;
|
||||||
const int i3 = iq3;
|
const int i3 = iq3;
|
||||||
|
|
||||||
ggml_vec_dot_f16(nek1,
|
// v indices
|
||||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
const int iv2 = iq2 % nev2;
|
||||||
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
const int iv3 = iq3;
|
||||||
|
|
||||||
|
ggml_vec_dot_f16(nev0,
|
||||||
|
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||||
|
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||||
S16);
|
S16);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -14723,9 +14748,13 @@ static void ggml_compute_forward_flash_attn_f16(
|
||||||
const int i2 = iq2;
|
const int i2 = iq2;
|
||||||
const int i3 = iq3;
|
const int i3 = iq3;
|
||||||
|
|
||||||
ggml_vec_dot_f16_unroll(nek1, nbv1,
|
// v indices
|
||||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
const int iv2 = iq2 % nev2;
|
||||||
((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
const int iv3 = iq3;
|
||||||
|
|
||||||
|
ggml_vec_dot_f16_unroll(nev0, nbv1,
|
||||||
|
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||||
|
((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||||
S16);
|
S16);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -14984,10 +15013,38 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// parallelize by q rows using ggml_vec_dot_f32
|
const int64_t elem_q = ggml_nelements(q);
|
||||||
|
const int64_t elem_k = ggml_nelements(k);
|
||||||
|
const int64_t elem_v = ggml_nelements(v);
|
||||||
|
|
||||||
// total rows in q
|
enum ggml_type result_type = dst->type;
|
||||||
const int nr = neq2*neq3;
|
GGML_ASSERT(ggml_blck_size(result_type) == 1);
|
||||||
|
const size_t tsize = ggml_type_size(result_type);
|
||||||
|
|
||||||
|
const size_t offs_q = 0;
|
||||||
|
const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
|
||||||
|
const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
|
||||||
|
|
||||||
|
void * grad_q = (char *) dst->data;
|
||||||
|
void * grad_k = (char *) dst->data + offs_k;
|
||||||
|
void * grad_v = (char *) dst->data + offs_v;
|
||||||
|
|
||||||
|
const size_t nbgq1 = nb0*neq0;
|
||||||
|
const size_t nbgq2 = nb0*neq0*neq1;
|
||||||
|
const size_t nbgq3 = nb0*neq0*neq1*neq2;
|
||||||
|
|
||||||
|
const size_t nbgk1 = nb0*nek0;
|
||||||
|
const size_t nbgk2 = nb0*nek0*nek1;
|
||||||
|
const size_t nbgk3 = nb0*nek0*nek1*neq2;
|
||||||
|
|
||||||
|
const size_t nbgv1 = nb0*nev0;
|
||||||
|
const size_t nbgv2 = nb0*nev0*nev1;
|
||||||
|
const size_t nbgv3 = nb0*nev0*nev1*neq2;
|
||||||
|
|
||||||
|
// parallelize by k rows using ggml_vec_dot_f32
|
||||||
|
|
||||||
|
// total rows in k
|
||||||
|
const int nr = nek2*nek3;
|
||||||
|
|
||||||
// rows per thread
|
// rows per thread
|
||||||
const int dr = (nr + nth - 1)/nth;
|
const int dr = (nr + nth - 1)/nth;
|
||||||
|
@ -15000,268 +15057,248 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
|
|
||||||
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
//printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
|
||||||
|
|
||||||
|
// how often k2 (and v2) is repeated in q2
|
||||||
|
int nrep = neq2/nek2;
|
||||||
|
|
||||||
for (int ir = ir0; ir < ir1; ++ir) {
|
for (int ir = ir0; ir < ir1; ++ir) {
|
||||||
// q indices
|
// q indices
|
||||||
const int iq3 = ir/(neq2);
|
const int ik3 = ir/(nek2);
|
||||||
const int iq2 = ir - iq3*neq2;
|
const int ik2 = ir - ik3*nek2;
|
||||||
for ( int iq1 = 0; iq1 < neq1; ++iq1) {
|
|
||||||
|
|
||||||
|
const int iq3 = ik3;
|
||||||
|
const int id3 = ik3;
|
||||||
|
const int iv3 = ik3;
|
||||||
|
const int iv2 = ik2;
|
||||||
|
|
||||||
// not sure about CACHE_LINE_SIZE_F32..
|
for (int irep = 0; irep < nrep; ++irep) {
|
||||||
// - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
|
const int iq2 = ik2 + irep*nek2;
|
||||||
float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
|
const int id2 = iq2;
|
||||||
float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
|
|
||||||
|
|
||||||
for (int i = M; i < Mup; ++i) {
|
// (ik2 + irep*nek2) % nek2 == ik2
|
||||||
S[i] = -INFINITY;
|
for (int iq1 = 0; iq1 < neq1; ++iq1) {
|
||||||
}
|
const int id1 = iq1;
|
||||||
|
|
||||||
for (int64_t ic = 0; ic < nek1; ++ic) {
|
// not sure about CACHE_LINE_SIZE_F32..
|
||||||
// k indices
|
// - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
|
||||||
const int ik3 = iq3;
|
float * S = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
|
||||||
const int ik2 = iq2;
|
float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
|
||||||
const int ik1 = ic;
|
|
||||||
|
|
||||||
// S indices
|
for (int i = M; i < Mup; ++i) {
|
||||||
const int i1 = ik1;
|
S[i] = -INFINITY;
|
||||||
|
|
||||||
ggml_vec_dot_f32(neq0,
|
|
||||||
S + i1,
|
|
||||||
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
|
||||||
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// scale
|
|
||||||
ggml_vec_scale_f32(nek1, S, scale);
|
|
||||||
|
|
||||||
if (masked) {
|
|
||||||
for (int64_t i = P; i < M; i++) {
|
|
||||||
if (i > P + iq1) {
|
|
||||||
S[i] = -INFINITY;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// softmax
|
for (int64_t ic = 0; ic < nek1; ++ic) {
|
||||||
{
|
// k indices
|
||||||
float max = -INFINITY;
|
const int ik1 = ic;
|
||||||
ggml_vec_max_f32(M, &max, S);
|
|
||||||
|
|
||||||
ggml_float sum = 0.0;
|
// S indices
|
||||||
{
|
const int i1 = ik1;
|
||||||
#ifdef GGML_SOFT_MAX_ACCELERATE
|
|
||||||
max = -max;
|
|
||||||
vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
|
|
||||||
vvexpf(SM, SM, &Mup);
|
|
||||||
ggml_vec_sum_f32(Mup, &sum, SM);
|
|
||||||
#else
|
|
||||||
uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
|
|
||||||
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
|
||||||
|
|
||||||
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
ggml_vec_dot_f32(neq0,
|
||||||
float * SR = S + i;
|
S + i1,
|
||||||
float * SW = SM + i;
|
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||||
|
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
||||||
|
}
|
||||||
|
|
||||||
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
// scale
|
||||||
if (SR[j] == -INFINITY) {
|
ggml_vec_scale_f32(nek1, S, scale);
|
||||||
SW[j] = 0.0f;
|
|
||||||
} else {
|
if (masked) {
|
||||||
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
for (int64_t i = P; i < M; i++) {
|
||||||
const float val = expf(SR[j] - max);
|
if (i > P + iq1) {
|
||||||
#else
|
S[i] = -INFINITY;
|
||||||
ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
|
|
||||||
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
|
||||||
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
|
|
||||||
#endif
|
|
||||||
sump[j] += (ggml_float)val;
|
|
||||||
SW[j] = val;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
// softmax
|
||||||
sum += sump[i];
|
// todo: exclude known -INF S[..] values from max and loop, assuming their results to be zero.
|
||||||
}
|
// dont forget to set their SM values to zero
|
||||||
|
{
|
||||||
|
float max = -INFINITY;
|
||||||
|
ggml_vec_max_f32(M, &max, S);
|
||||||
|
|
||||||
|
ggml_float sum = 0.0;
|
||||||
|
{
|
||||||
|
#ifdef GGML_SOFT_MAX_ACCELERATE
|
||||||
|
max = -max;
|
||||||
|
vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
|
||||||
|
vvexpf(SM, SM, &Mup);
|
||||||
|
ggml_vec_sum_f32(Mup, &sum, SM);
|
||||||
|
#else
|
||||||
|
uint16_t scvt[GGML_SOFT_MAX_UNROLL]; UNUSED(scvt);
|
||||||
|
ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
|
||||||
|
|
||||||
|
for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
|
||||||
|
float * SR = S + i;
|
||||||
|
float * SW = SM + i;
|
||||||
|
|
||||||
|
for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
|
||||||
|
if (SR[j] == -INFINITY) {
|
||||||
|
SW[j] = 0.0f;
|
||||||
|
} else {
|
||||||
|
#ifndef GGML_FLASH_ATTN_EXP_FP16
|
||||||
|
const float val = expf(SR[j] - max);
|
||||||
|
#else
|
||||||
|
ggml_fp16_t s = GGML_FP32_TO_FP16(SR[j] - max);
|
||||||
|
memcpy(&scvt[j], &s, sizeof(uint16_t));
|
||||||
|
const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
|
||||||
#endif
|
#endif
|
||||||
|
sump[j] += (ggml_float)val;
|
||||||
|
SW[j] = val;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
|
||||||
|
sum += sump[i];
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
assert(sum > 0.0);
|
||||||
|
|
||||||
|
sum = 1.0/sum;
|
||||||
|
ggml_vec_scale_f32(M, SM, sum);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
assert(sum > 0.0);
|
// step-by-step explanation
|
||||||
|
{
|
||||||
|
// forward-process shape grads from backward process
|
||||||
|
// parallel_for ik2,ik3:
|
||||||
|
// for irep:
|
||||||
|
// iq2 = ik2 + irep*nek2
|
||||||
|
// k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,ik2,ik3] += grad[kcur]
|
||||||
|
// q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
|
||||||
|
// v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iv2,iv3] += grad[vcur]
|
||||||
|
// for iq1:
|
||||||
|
// kcur = k[:D,:M,ik2,ik3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
|
||||||
|
// qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
|
||||||
|
// vcur = v[:M,:D,iv2,iv3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
|
||||||
|
// S0 = -Inf [D,1,1,1]
|
||||||
|
// ~S1[i] = dot(kcur[:D,i], qcur)
|
||||||
|
// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
|
||||||
|
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
|
||||||
|
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
||||||
|
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
|
||||||
|
// ~S5[i] = dot(vcur[:,i], S4)
|
||||||
|
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,id1,id2,id3]
|
||||||
|
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
|
||||||
|
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,id1,id2,id3]
|
||||||
|
// dst backward-/ grad[dst] = d
|
||||||
|
//
|
||||||
|
// output gradients with their dependencies:
|
||||||
|
//
|
||||||
|
// grad[kcur] = grad[S1].T @ qcur
|
||||||
|
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
||||||
|
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
||||||
|
// grad[S4] = grad[S5] @ vcur
|
||||||
|
// grad[S4] = d[:D,id1,id2,id3] @ vcur
|
||||||
|
// grad[qcur] = grad[S1] @ kcur
|
||||||
|
// grad[vcur] = grad[S5].T @ S4
|
||||||
|
// grad[vcur] = d[:D,id1,id2,id3].T @ S4
|
||||||
|
//
|
||||||
|
// in post-order:
|
||||||
|
//
|
||||||
|
// S1 = qcur @ kcur.T
|
||||||
|
// S2 = S1 * scale
|
||||||
|
// S3 = diag_mask_inf(S2, P)
|
||||||
|
// S4 = softmax(S3)
|
||||||
|
// grad[S4] = d[:D,id1,id2,id3] @ vcur
|
||||||
|
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
||||||
|
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
||||||
|
// grad[qcur] = grad[S1] @ kcur
|
||||||
|
// grad[kcur] = grad[S1].T @ qcur
|
||||||
|
// grad[vcur] = d[:D,id1,id2,id3].T @ S4
|
||||||
|
//
|
||||||
|
// using less variables (SM=S4):
|
||||||
|
//
|
||||||
|
// S = diag_mask_inf(qcur @ kcur.T * scale, P)
|
||||||
|
// SM = softmax(S)
|
||||||
|
// S = d[:D,iq1,iq2,iq3] @ vcur
|
||||||
|
// dot_SM_gradSM = dot(SM, S)
|
||||||
|
// S = SM * (S - dot(SM, S))
|
||||||
|
// S = diag_mask_zero(S, P) * scale
|
||||||
|
//
|
||||||
|
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
||||||
|
// grad[k][:D,:M,ik2,ik3] += S.T @ qcur
|
||||||
|
// grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
|
||||||
|
}
|
||||||
|
|
||||||
sum = 1.0/sum;
|
// S = gradSM = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
|
||||||
ggml_vec_scale_f32(M, SM, sum);
|
// S = d[:D,id1,id2,id3] @ vcur[:,:,iv2,iv3]
|
||||||
|
// for ic:
|
||||||
|
// S[:M] += vcur[:M,ic,iv2,iv3] * d[ic,id1,id2,id3]
|
||||||
|
ggml_vec_set_f32(M, S, 0);
|
||||||
|
for (int64_t ic = 0; ic < D; ++ic) {
|
||||||
|
ggml_vec_mad_f32(M,
|
||||||
|
S,
|
||||||
|
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||||
|
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
|
||||||
|
}
|
||||||
|
|
||||||
}
|
// S = SM * (S - dot(SM, S))
|
||||||
|
float dot_SM_gradSM = 0;
|
||||||
|
ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
|
||||||
|
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
|
||||||
|
ggml_vec_mul_f32 (M, S, S, SM);
|
||||||
|
|
||||||
// step-by-step explanation
|
// S = diag_mask_zero(S, P) * scale
|
||||||
{
|
if (masked) {
|
||||||
// forward-process shape grads from backward process
|
// for (int64_t i = P + iq1 + 1; i < M; i++) {
|
||||||
// parallel_for iq2,iq3:
|
// S[i] = 0;
|
||||||
// k[:D,:M,:,:] [D,M,:,:] grad[k][:D,:M,iq2,iq3] += grad[kcur]
|
// }
|
||||||
// q[:D,:N,:,:] [D,N,:,:] grad[q][:D,iq1,iq2,iq3] += grad[qcur]
|
for (int64_t i = P; i < M; i++) {
|
||||||
// v[:M,:D,:,:] [M,D,:,:] grad[v][:M,:D,iq2,iq3] += grad[vcur]
|
if (i > P + iq1) {
|
||||||
// for iq1:
|
S[i] = 0;
|
||||||
// kcur = k[:D,:M,iq2,iq3] [D,M,1,1] grad[kcur] = grad[S1].T @ qcur
|
}
|
||||||
// qcur = q[:D,iq1,iq2,iq3] [D,1,1,1] grad[qcur] = grad[S1] @ kcur
|
|
||||||
// vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
|
|
||||||
// S0 = -Inf [D,1,1,1]
|
|
||||||
// ~S1[i] = dot(kcur[:D,i], qcur)
|
|
||||||
// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
|
|
||||||
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
|
|
||||||
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
|
||||||
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
|
|
||||||
// ~S5[i] = dot(vcur[:,i], S4)
|
|
||||||
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
|
|
||||||
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
|
|
||||||
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
|
|
||||||
// dst backward-/ grad[dst] = d
|
|
||||||
//
|
|
||||||
// output gradients with their dependencies:
|
|
||||||
//
|
|
||||||
// grad[kcur] = grad[S1].T @ qcur
|
|
||||||
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
|
||||||
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
|
||||||
// grad[S4] = grad[S5] @ vcur
|
|
||||||
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
|
|
||||||
// grad[qcur] = grad[S1] @ kcur
|
|
||||||
// grad[vcur] = grad[S5].T @ S4
|
|
||||||
// grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
|
|
||||||
//
|
|
||||||
// in post-order:
|
|
||||||
//
|
|
||||||
// S1 = qcur @ kcur.T
|
|
||||||
// S2 = S1 * scale
|
|
||||||
// S3 = diag_mask_inf(S2, P)
|
|
||||||
// S4 = softmax(S3)
|
|
||||||
// grad[S4] = d[:D,iq1,iq2,iq3] @ vcur
|
|
||||||
// grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
|
||||||
// grad[S1] = diag_mask_zero(grad[S3], P) * scale
|
|
||||||
// grad[qcur] = grad[S1] @ kcur
|
|
||||||
// grad[kcur] = grad[S1].T @ qcur
|
|
||||||
// grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
|
|
||||||
//
|
|
||||||
// using less variables (SM=S4):
|
|
||||||
//
|
|
||||||
// S = diag_mask_inf(qcur @ kcur.T * scale, P)
|
|
||||||
// SM = softmax(S)
|
|
||||||
// S = d[:D,iq1,iq2,iq3] @ vcur
|
|
||||||
// dot_SM_gradSM = dot(SM, S)
|
|
||||||
// S = SM * (S - dot(SM, S))
|
|
||||||
// S = diag_mask_zero(S, P) * scale
|
|
||||||
//
|
|
||||||
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
|
||||||
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
|
|
||||||
// grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
|
|
||||||
}
|
|
||||||
|
|
||||||
// S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
|
|
||||||
// S = d[:D,iq1,iq2,iq3] @ vcur
|
|
||||||
// S[:M] += vcur[:M,ic] * d[ic,iq1,iq2,iq3]
|
|
||||||
ggml_vec_set_f32(M, S, 0);
|
|
||||||
for (int64_t ic = 0; ic < D; ++ic) {
|
|
||||||
// dst indices
|
|
||||||
const int i1 = iq1;
|
|
||||||
const int i2 = iq2;
|
|
||||||
const int i3 = iq3;
|
|
||||||
|
|
||||||
ggml_vec_mad_f32(M,
|
|
||||||
S,
|
|
||||||
(float *) ((char *) v->data + ( ic*nbv1 + i2*nbv2 + i3*nbv3)),
|
|
||||||
*(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// S = SM * (S - dot(SM, S))
|
|
||||||
float dot_SM_gradSM = 0;
|
|
||||||
ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
|
|
||||||
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
|
|
||||||
ggml_vec_mul_f32 (M, S, S, SM);
|
|
||||||
|
|
||||||
// S = diag_mask_zero(S, P) * scale
|
|
||||||
if (masked) {
|
|
||||||
// for (int64_t i = P + iq1 + 1; i < M; i++) {
|
|
||||||
// S[i] = 0;
|
|
||||||
// }
|
|
||||||
for (int64_t i = P; i < M; i++) {
|
|
||||||
if (i > P + iq1) {
|
|
||||||
S[i] = 0;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
// todo: exclude known zero S[..] values from operation
|
||||||
ggml_vec_scale_f32(M, S, scale);
|
ggml_vec_scale_f32(M, S, scale);
|
||||||
|
|
||||||
void * grad_q = (char *) dst->data;
|
// S shape [M,1]
|
||||||
void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
|
// SM shape [M,1]
|
||||||
void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
|
// kcur shape [D,M]
|
||||||
|
// qcur shape [D,1]
|
||||||
|
// vcur shape [M,D]
|
||||||
|
|
||||||
const size_t nbgq1 = nb0*neq0;
|
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
||||||
const size_t nbgq2 = nb0*neq0*neq1;
|
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
|
||||||
const size_t nbgq3 = nb0*neq0*neq1*neq2;
|
// for ic:
|
||||||
|
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic,ik2,ik3]
|
||||||
|
// todo: exclude known zero S[..] values from loop
|
||||||
|
for (int64_t ic = 0; ic < M; ++ic) {
|
||||||
|
ggml_vec_mad_f32(D,
|
||||||
|
(float *) ((char *) grad_q + (iq1*nbgq1 + iq2*nbgq2 + iq3*nbgq3)),
|
||||||
|
(float *) ((char *) k->data + (ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||||
|
S[ic]);
|
||||||
|
}
|
||||||
|
|
||||||
const size_t nbgk1 = nb0*nek0;
|
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
|
||||||
const size_t nbgk2 = nb0*nek0*nek1;
|
// for ic:
|
||||||
const size_t nbgk3 = nb0*nek0*nek1*neq2;
|
// grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
|
||||||
|
// grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
|
||||||
|
// todo: exclude known zero S[..] values from loop
|
||||||
|
for (int64_t ic = 0; ic < M; ++ic) {
|
||||||
|
ggml_vec_mad_f32(D,
|
||||||
|
(float *) ((char *) grad_k + (ic*nbgk1 + ik2*nbgk2 + ik3*nbgk3)),
|
||||||
|
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)),
|
||||||
|
S[ic]);
|
||||||
|
}
|
||||||
|
|
||||||
const size_t nbgv1 = nb0*nev0;
|
// grad[v][:M,:D,iv2,iv3] += d[:D,id1,id2,id3].T @ SM
|
||||||
const size_t nbgv2 = nb0*nev0*nev1;
|
// for ic:
|
||||||
const size_t nbgv3 = nb0*nev0*nev1*neq2;
|
// grad[v][:M,ic,iv2,iv3] += d[:D,id1,id2,id3].T[0,ic] * SM[:M]
|
||||||
|
// grad[v][:M,ic,iv2,iv3] += d[ic,id1,id2,id3] * SM[:M]
|
||||||
// S shape [M,1]
|
// todo: exclude known zero SM[..] values from mad
|
||||||
// SM shape [M,1]
|
for (int64_t ic = 0; ic < D; ++ic) {
|
||||||
// kcur shape [D,M]
|
ggml_vec_mad_f32(M,
|
||||||
// qcur shape [D,1]
|
(float *) ((char *) grad_v + ( ic*nbgv1 + iv2*nbgv2 + iv3*nbgv3)),
|
||||||
// vcur shape [M,D]
|
SM,
|
||||||
//
|
*(float *) ((char *) d->data + (ic*nbd0 + id1*nbd1 + id2*nbd2 + id3*nbd3)));
|
||||||
// grad[q][:D,iq1,iq2,iq3] += S @ kcur
|
}
|
||||||
// grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
|
|
||||||
// grad[q][:D,iq1,iq2,iq3] += S[ic] * kcur[:D,ic]
|
|
||||||
//
|
|
||||||
//// grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
|
|
||||||
//// grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
|
|
||||||
for (int64_t ic = 0; ic < M; ++ic) {
|
|
||||||
// dst indices
|
|
||||||
const int i1 = iq1;
|
|
||||||
const int i2 = iq2;
|
|
||||||
const int i3 = iq3;
|
|
||||||
|
|
||||||
ggml_vec_mad_f32(D,
|
|
||||||
(float *) ((char *) grad_q + (i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
|
|
||||||
(float *) ((char *) k->data + (ic*nbk1 + i2*nbk2 + i3*nbk3)),
|
|
||||||
S[ic]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// grad[k][:D,:M,iq2,iq3] += S.T @ qcur
|
|
||||||
// grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
|
|
||||||
// grad[k][:D,ic,iq2,iq3] += S[ic] * qcur[:D,0]
|
|
||||||
for (int64_t ic = 0; ic < M; ++ic) {
|
|
||||||
// dst indices
|
|
||||||
const int i1 = iq1;
|
|
||||||
const int i2 = iq2;
|
|
||||||
const int i3 = iq3;
|
|
||||||
|
|
||||||
// ggml_vec_set_f32(D,
|
|
||||||
// (float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
|
|
||||||
// 0);
|
|
||||||
ggml_vec_mad_f32(D,
|
|
||||||
(float *) ((char *) grad_k + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
|
|
||||||
(float *) ((char *) q->data + (i1*nbq1 + i2*nbq2 + i3*nbq3)),
|
|
||||||
S[ic]);
|
|
||||||
}
|
|
||||||
|
|
||||||
// grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T @ SM
|
|
||||||
// grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
|
|
||||||
// grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3] * SM[:M]
|
|
||||||
for (int64_t ic = 0; ic < D; ++ic) {
|
|
||||||
// dst indices
|
|
||||||
const int i1 = iq1;
|
|
||||||
const int i2 = iq2;
|
|
||||||
const int i3 = iq3;
|
|
||||||
|
|
||||||
// ggml_vec_set_f32(M,
|
|
||||||
// (float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
|
|
||||||
// 0);
|
|
||||||
ggml_vec_mad_f32(M,
|
|
||||||
(float *) ((char *) grad_v + ( ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
|
|
||||||
SM,
|
|
||||||
*(float *) ((char *) d->data + (ic*nbd0 + i1*nbd1 + i2*nbd2 + i3*nbd3)));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -17166,143 +17203,40 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
||||||
masked);
|
masked);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (src0->grad) {
|
struct ggml_tensor * src2 = tensor->src[2];
|
||||||
struct ggml_tensor * grad_q = NULL;
|
const int64_t elem_q = ggml_nelements(src0);
|
||||||
const size_t nb0 = flash_grad->nb[0];
|
const int64_t elem_k = ggml_nelements(src1);
|
||||||
const size_t offset = 0;
|
const int64_t elem_v = ggml_nelements(src2);
|
||||||
switch(src0->n_dims) {
|
|
||||||
case 2:
|
|
||||||
{
|
|
||||||
grad_q = ggml_view_2d(ctx,
|
|
||||||
flash_grad,
|
|
||||||
src0->ne[0],
|
|
||||||
src0->ne[1],
|
|
||||||
nb0*src0->ne[0],
|
|
||||||
offset);
|
|
||||||
} break;
|
|
||||||
case 3:
|
|
||||||
{
|
|
||||||
grad_q = ggml_view_3d(ctx,
|
|
||||||
flash_grad,
|
|
||||||
src0->ne[0],
|
|
||||||
src0->ne[1],
|
|
||||||
src0->ne[2],
|
|
||||||
nb0*src0->ne[0],
|
|
||||||
nb0*src0->ne[0]*src0->ne[1],
|
|
||||||
offset);
|
|
||||||
} break;
|
|
||||||
case 4:
|
|
||||||
{
|
|
||||||
grad_q = ggml_view_4d(ctx,
|
|
||||||
flash_grad,
|
|
||||||
src0->ne[0],
|
|
||||||
src0->ne[1],
|
|
||||||
src0->ne[2],
|
|
||||||
src0->ne[3],
|
|
||||||
nb0*src0->ne[0],
|
|
||||||
nb0*src0->ne[0]*src0->ne[1],
|
|
||||||
nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
|
|
||||||
offset);
|
|
||||||
} break;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
enum ggml_type result_type = flash_grad->type;
|
||||||
|
GGML_ASSERT(ggml_blck_size(result_type) == 1);
|
||||||
|
const size_t tsize = ggml_type_size(result_type);
|
||||||
|
|
||||||
|
const size_t offs_q = 0;
|
||||||
|
const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN);
|
||||||
|
const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN);
|
||||||
|
|
||||||
|
if (src0->grad) {
|
||||||
|
struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q);
|
||||||
|
struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0);
|
||||||
src0->grad = ggml_add_or_set(ctx,
|
src0->grad = ggml_add_or_set(ctx,
|
||||||
src0->grad,
|
src0->grad,
|
||||||
grad_q,
|
grad_q,
|
||||||
zero_table);
|
zero_table);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (src1->grad) {
|
if (src1->grad) {
|
||||||
struct ggml_tensor * grad_k = NULL;
|
struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
|
||||||
const size_t nb0 = flash_grad->nb[0];
|
struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1);
|
||||||
const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
|
|
||||||
switch(src1->n_dims) {
|
|
||||||
case 2:
|
|
||||||
{
|
|
||||||
grad_k = ggml_view_2d(ctx,
|
|
||||||
flash_grad,
|
|
||||||
src1->ne[0],
|
|
||||||
src1->ne[1],
|
|
||||||
nb0*src1->ne[0],
|
|
||||||
offset);
|
|
||||||
} break;
|
|
||||||
case 3:
|
|
||||||
{
|
|
||||||
grad_k = ggml_view_3d(ctx,
|
|
||||||
flash_grad,
|
|
||||||
src1->ne[0],
|
|
||||||
src1->ne[1],
|
|
||||||
src1->ne[2],
|
|
||||||
nb0*src1->ne[0],
|
|
||||||
nb0*src1->ne[0]*src1->ne[1],
|
|
||||||
offset);
|
|
||||||
} break;
|
|
||||||
case 4:
|
|
||||||
{
|
|
||||||
grad_k = ggml_view_4d(ctx,
|
|
||||||
flash_grad,
|
|
||||||
src1->ne[0],
|
|
||||||
src1->ne[1],
|
|
||||||
src1->ne[2],
|
|
||||||
src1->ne[3],
|
|
||||||
nb0*src1->ne[0],
|
|
||||||
nb0*src1->ne[0]*src1->ne[1],
|
|
||||||
nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
|
|
||||||
offset);
|
|
||||||
} break;
|
|
||||||
}
|
|
||||||
|
|
||||||
src1->grad = ggml_add_or_set(ctx,
|
src1->grad = ggml_add_or_set(ctx,
|
||||||
src1->grad,
|
src1->grad,
|
||||||
grad_k,
|
grad_k,
|
||||||
zero_table);
|
zero_table);
|
||||||
}
|
}
|
||||||
|
if (src2->grad) {
|
||||||
struct ggml_tensor * opt0 = tensor->src[2];
|
struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
|
||||||
|
struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2);
|
||||||
if (opt0->grad) {
|
src2->grad = ggml_add_or_set(ctx,
|
||||||
struct ggml_tensor * grad_v = NULL;
|
src2->grad,
|
||||||
const size_t nb0 = flash_grad->nb[0];
|
|
||||||
const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
|
|
||||||
+ nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
|
|
||||||
switch(opt0->n_dims) {
|
|
||||||
case 2:
|
|
||||||
{
|
|
||||||
grad_v = ggml_view_2d(ctx,
|
|
||||||
flash_grad,
|
|
||||||
opt0->ne[0],
|
|
||||||
opt0->ne[1],
|
|
||||||
nb0*opt0->ne[0],
|
|
||||||
offset);
|
|
||||||
} break;
|
|
||||||
case 3:
|
|
||||||
{
|
|
||||||
grad_v = ggml_view_3d(ctx,
|
|
||||||
flash_grad,
|
|
||||||
opt0->ne[0],
|
|
||||||
opt0->ne[1],
|
|
||||||
opt0->ne[2],
|
|
||||||
nb0*opt0->ne[0],
|
|
||||||
nb0*opt0->ne[0]*opt0->ne[1],
|
|
||||||
offset);
|
|
||||||
} break;
|
|
||||||
case 4:
|
|
||||||
{
|
|
||||||
grad_v = ggml_view_4d(ctx,
|
|
||||||
flash_grad,
|
|
||||||
opt0->ne[0],
|
|
||||||
opt0->ne[1],
|
|
||||||
opt0->ne[2],
|
|
||||||
opt0->ne[3],
|
|
||||||
nb0*opt0->ne[0],
|
|
||||||
nb0*opt0->ne[0]*opt0->ne[1],
|
|
||||||
nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
|
|
||||||
offset);
|
|
||||||
} break;
|
|
||||||
}
|
|
||||||
|
|
||||||
opt0->grad = ggml_add_or_set(ctx,
|
|
||||||
opt0->grad,
|
|
||||||
grad_v,
|
grad_v,
|
||||||
zero_table);
|
zero_table);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1359,6 +1359,10 @@ int main(int argc, const char ** argv) {
|
||||||
ggml_new_f32(ctx0, eps))));
|
ggml_new_f32(ctx0, eps))));
|
||||||
|
|
||||||
check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
|
check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY);
|
||||||
|
// NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf.
|
||||||
|
// this may result in different gradients too finite differences.
|
||||||
|
// when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause.
|
||||||
|
// if only the table lookup causes gradients to differ this is acceptable.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1474,28 +1478,31 @@ int main(int argc, const char ** argv) {
|
||||||
|
|
||||||
for (int masked = 0; masked <= 1; ++masked) {
|
for (int masked = 0; masked <= 1; ++masked) {
|
||||||
for (int ndims = 2; ndims <= 4; ++ndims) {
|
for (int ndims = 2; ndims <= 4; ++ndims) {
|
||||||
int64_t neq[4] = { D, N, B, ne[3] };
|
int max_nrep = (ndims >= 3) ? 2 : 1;
|
||||||
int64_t nek[4] = { D, M, B, ne[3] };
|
for (int nrep = 1; nrep < max_nrep; ++nrep) {
|
||||||
int64_t nev[4] = { M, D, B, ne[3] };
|
int64_t neq[4] = { D, N, B*nrep, ne[3] };
|
||||||
if (ndims == 2) {
|
int64_t nek[4] = { D, M, B, ne[3] };
|
||||||
neq[2] = 1; neq[3] = 1;
|
int64_t nev[4] = { M, D, B, ne[3] };
|
||||||
nek[2] = 1; nek[3] = 1;
|
if (ndims == 2) {
|
||||||
nev[2] = 1; nev[3] = 1;
|
neq[2] = 1; neq[3] = 1;
|
||||||
} else if (ndims == 3) {
|
nek[2] = 1; nek[3] = 1;
|
||||||
neq[3] = 1;
|
nev[2] = 1; nev[3] = 1;
|
||||||
nek[3] = 1;
|
} else if (ndims == 3) {
|
||||||
nev[3] = 1;
|
neq[3] = 1;
|
||||||
|
nek[3] = 1;
|
||||||
|
nev[3] = 1;
|
||||||
|
}
|
||||||
|
x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
||||||
|
x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
||||||
|
x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
||||||
|
ggml_set_param(ctx0, x[0]);
|
||||||
|
ggml_set_param(ctx0, x[1]);
|
||||||
|
ggml_set_param(ctx0, x[2]);
|
||||||
|
|
||||||
|
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
||||||
|
|
||||||
|
check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
||||||
}
|
}
|
||||||
x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f);
|
|
||||||
x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f);
|
|
||||||
x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f);
|
|
||||||
ggml_set_param(ctx0, x[0]);
|
|
||||||
ggml_set_param(ctx0, x[1]);
|
|
||||||
ggml_set_param(ctx0, x[2]);
|
|
||||||
|
|
||||||
struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0)));
|
|
||||||
|
|
||||||
check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue