implement backward pass of flash attention
parent 56895e28f6
commit 22a7279ffb

2 changed files with 652 additions and 4 deletions

ggml.c (647 changed lines)

@ -3336,6 +3336,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
 
     "FLASH_ATTN",
     "FLASH_FF",
+    "FLASH_ATTN_BACK",
 
     "MAP_UNARY",
     "MAP_BINARY",
@ -3344,7 +3345,7 @@ static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
+static_assert(GGML_OP_COUNT == 56, "GGML_OP_COUNT != 56");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@ -3402,6 +3403,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
 
     "flash_attn(x)",
     "flash_ff(x)",
+    "flash_attn_back(x)",
 
     "f(x)",
     "f(x,y)",
@ -3410,7 +3412,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 55, "GGML_OP_COUNT != 55");
+static_assert(GGML_OP_COUNT == 56, "GGML_OP_COUNT != 56");
 
 static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
 static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
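
Aside (not part of the commit): the two static_asserts above are what keep the GGML_OP_LABEL / GGML_OP_SYMBOL string tables in sync with enum ggml_op, which is why adding FLASH_ATTN_BACK bumps the count from 55 to 56 in both places. A tiny standalone illustration of the pattern (names here are made up, not ggml's):

    // standalone sketch, not from ggml: enum/table sync guarded by a static assert
    #include <stdio.h>

    enum op { OP_NONE, OP_FLASH_ATTN, OP_FLASH_ATTN_BACK, OP_COUNT };

    static const char * OP_LABEL[OP_COUNT] = { "NONE", "FLASH_ATTN", "FLASH_ATTN_BACK" };

    // fails to compile if a new enum entry is added without updating this check
    _Static_assert(OP_COUNT == 3, "OP_COUNT != 3");

    int main(void) {
        printf("%s\n", OP_LABEL[OP_FLASH_ATTN_BACK]);
        return 0;
    }
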
@ -6251,7 +6253,6 @@ struct ggml_tensor * ggml_flash_ff(
     bool is_node = false;
 
     if (a->grad || b0->grad || b1->grad || c0->grad || c1->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
         is_node = true;
     }
 
@ -6269,6 +6270,71 @@ struct ggml_tensor * ggml_flash_ff(
     return result;
 }
 
+// ggml_flash_attn_back
+
+struct ggml_tensor * ggml_flash_attn_back(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * d,
+        bool                  masked) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+
+    // d shape [D,N,ne2,ne3]
+    // q shape [D,N,ne2,ne3]
+    // k shape [D,M,ne2,ne3]
+    // v shape [M,D,ne2,ne3]
+
+    const int64_t   D = q->ne[0];
+    const int64_t   N = q->ne[1];
+    const int64_t   M = k->ne[1];
+    const int64_t ne2 = q->ne[2];
+    const int64_t ne3 = q->ne[3];
+
+    GGML_ASSERT(k->ne[0] == D);
+    GGML_ASSERT(v->ne[0] == M);
+    GGML_ASSERT(v->ne[1] == D);
+    GGML_ASSERT(d->ne[0] == D);
+    GGML_ASSERT(d->ne[1] == N);
+    GGML_ASSERT(k->ne[2] == ne2);
+    GGML_ASSERT(k->ne[3] == ne3);
+    GGML_ASSERT(v->ne[2] == ne2);
+    GGML_ASSERT(v->ne[3] == ne3);
+    GGML_ASSERT(d->ne[2] == ne2);
+    GGML_ASSERT(d->ne[3] == ne3);
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        // when using this operation (in backwards pass) these grads are set.
+        // we don't want to create (big) grad of our result, so is_node is false.
+        is_node = false;
+    }
+
+    // store gradients of q, k and v as continuous tensors concatenated in result.
+    // q shape[D,N,ne2,ne3] ; k shape [D,M,ne2,ne3] ; v shape [M,D,ne2,ne3]
+    // gradq->data = result->data
+    // gradk->data = result->data + nb0*D*N*ne2*ne3
+    // gradv->data = result->data + nb0*D*N*ne2*ne3 + nb0*D*M*ne2*ne3
+    // note: v and gradv are actually transposed, i.e. v->ne[0] != D.
+    int64_t ne[4] = {D,M+N+M,ne2,ne3};
+
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
+
+    result->op   = GGML_OP_FLASH_ATTN_BACK;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src0 = q;
+    result->src1 = k;
+    result->opt[0] = v;
+    result->opt[1] = d;
+    result->opt[2] = ggml_new_i32(ctx, masked ? 1 : 0);
+
+    return result;
+}
+
 // ggml_map_unary
 
 struct ggml_tensor * ggml_map_unary_impl_f32(
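
Aside (not part of the commit): ggml_flash_attn_back returns all three gradients packed into a single contiguous F32 tensor of shape [D, M+N+M, ne2, ne3], at the offsets listed in the comments above. A small standalone sketch of that arithmetic, with made-up dimensions:

    // standalone sketch (not from the commit): offsets of the packed gradients
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        // made-up example dimensions
        const int64_t D = 64, N = 32, M = 48, ne2 = 4, ne3 = 1;
        const size_t  nb0 = sizeof(float); // element size of the F32 result

        const size_t off_q = 0;                       // grad[q], shape [D,N,ne2,ne3]
        const size_t off_k = off_q + nb0*D*N*ne2*ne3; // grad[k], shape [D,M,ne2,ne3]
        const size_t off_v = off_k + nb0*D*M*ne2*ne3; // grad[v], shape [M,D,ne2,ne3] (transposed like v)
        const size_t total = off_v + nb0*M*D*ne2*ne3;

        // matches the allocation of [D, M+N+M, ne2, ne3] elements
        printf("q at %zu, k at %zu, v at %zu, total %zu bytes (%lld elements)\n",
               off_q, off_k, off_v, total, (long long)(D*(M+N+M)*ne2*ne3));
        return 0;
    }

The backward pass later slices these three regions back out as views (see the GGML_OP_FLASH_ATTN case in ggml_compute_backward below).
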
@ -12788,6 +12854,394 @@ static void ggml_compute_forward_flash_ff(
     }
 }
 
+// ggml_compute_forward_flash_attn_back
+
+static void ggml_compute_forward_flash_attn_back_f32(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+              struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    const int64_t neq0 = q->ne[0];
+    const int64_t neq1 = q->ne[1];
+    const int64_t neq2 = q->ne[2];
+    const int64_t neq3 = q->ne[3];
+
+    const int64_t nek0 = k->ne[0];
+    const int64_t nek1 = k->ne[1];
+    //const int64_t nek2 = k->ne[2];
+    //const int64_t nek3 = k->ne[3];
+
+    const int64_t nev0 = v->ne[0];
+    const int64_t nev1 = v->ne[1];
+    //const int64_t nev2 = v->ne[2];
+    //const int64_t nev3 = v->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
+    //const int64_t ne2 = dst->ne[2];
+    //const int64_t ne3 = dst->ne[3];
+
+    const int nbk0 = k->nb[0];
+    const int nbk1 = k->nb[1];
+    const int nbk2 = k->nb[2];
+    const int nbk3 = k->nb[3];
+
+    const int nbq0 = q->nb[0];
+    const int nbq1 = q->nb[1];
+    const int nbq2 = q->nb[2];
+    const int nbq3 = q->nb[3];
+
+    const int nbv0 = v->nb[0];
+    const int nbv1 = v->nb[1];
+    const int nbv2 = v->nb[2];
+    const int nbv3 = v->nb[3];
+
+    const int nbd0 = d->nb[0];
+    const int nbd1 = d->nb[1];
+    const int nbd2 = d->nb[2];
+    const int nbd3 = d->nb[3];
+
+    const int nb0 = dst->nb[0];
+    const int nb1 = dst->nb[1];
+    const int nb2 = dst->nb[2];
+    const int nb3 = dst->nb[3];
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+    const int64_t P = nek1 - N;
+    const int64_t M = P + N;
+
+    const int Mup  = ggml_up(M, GGML_SOFT_MAX_UNROLL);
+    const int mxDM = MAX(D, Mup);
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne1 == N);
+    GGML_ASSERT(P >= 0);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(float));
+    GGML_ASSERT(nbv0 == sizeof(float));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev1 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nek1 == N + P);
+    GGML_ASSERT(nev1 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    const float scale = 1.0f/sqrtf(D);
+
+    //printf("P=%d N=%d D=%d ir0=%d ir1=%d scale = %f\n", P, N, D, ir0, ir1, scale);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        // not sure about CACHE_LINE_SIZE_F32..
+        // - maybe it must not be multiplied by 2 and excluded from .. in SM 1*(..) offset?
+        float * S  = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 0*(mxDM+CACHE_LINE_SIZE_F32);
+        float * SM = (float *) params->wdata + ith*2*(mxDM + CACHE_LINE_SIZE_F32) + 1*(mxDM+CACHE_LINE_SIZE_F32);
+
+        for (int i = M; i < Mup; ++i) {
+            S[i] = -INFINITY;
+        }
+
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            // k indices
+            const int ik3 = iq3;
+            const int ik2 = iq2;
+            const int ik1 = ic;
+
+            // S indices
+            const int i1 = ik1;
+
+            ggml_vec_dot_f32(neq0,
+                    S + i1,
+                    (float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                    (float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
+        }
+
+        // scale
+        ggml_vec_scale_f32(nek1, S, scale);
+
+        if (masked) {
+            for (int64_t i = P; i < M; i++) {
+                if (i > P + iq1) {
+                    S[i] = -INFINITY;
+                }
+            }
+        }
+
+        // softmax
+        {
+            float max = -INFINITY;
+            ggml_vec_max_f32(M, &max, S);
+
+            ggml_float sum = 0.0;
+            {
+#ifdef GGML_SOFT_MAX_ACCELERATE
+                max = -max;
+                vDSP_vsadd(SM, 1, &max, SM, 1, Mup);
+                vvexpf(SM, SM, &Mup);
+                ggml_vec_sum_f32(Mup, &sum, SM);
+#else
+                uint16_t   scvt[GGML_SOFT_MAX_UNROLL];
+                ggml_float sump[GGML_SOFT_MAX_UNROLL] = { 0.0 };
+
+                for (int i = 0; i < Mup; i += GGML_SOFT_MAX_UNROLL) {
+                    float * SS = SM + i;
+
+                    for (int j = 0; j < GGML_SOFT_MAX_UNROLL; ++j) {
+                        if (SS[j] == -INFINITY) {
+                            SS[j] = 0.0f;
+                        } else {
+                            ggml_fp16_t s = GGML_FP32_TO_FP16(SS[j] - max);
+                            memcpy(&scvt[j], &s, sizeof(uint16_t));
+                            const float val = GGML_FP16_TO_FP32(table_exp_f16[scvt[j]]);
+                            sump[j] += (ggml_float)val;
+                            SS[j] = val;
+                        }
+                    }
+                }
+
+                for (int i = 0; i < GGML_SOFT_MAX_UNROLL; i++) {
+                    sum += sump[i];
+                }
+#endif
+            }
+
+            assert(sum > 0.0);
+
+            sum = 1.0/sum;
+            ggml_vec_scale_f32(M, SM, sum);
+        }
+
+        // step-by-step explanation
+        {
+            // forward-process                    shape        grads from backward process
+            // parallel_for iq2,iq3:
+            //  k[:D,:M,:,:]                     [D,M,:,:]     grad[k][:D,:M,iq2,iq3]  += grad[kcur]
+            //  q[:D,:N,:,:]                     [D,N,:,:]     grad[q][:D,iq1,iq2,iq3] += grad[qcur]
+            //  v[:M,:D,:,:]                     [M,D,:,:]     grad[v][:M,:D,iq2,iq3]  += grad[vcur]
+            //  for iq1:
+            //   kcur   = k[:D,:M,iq2,iq3]       [D,M,1,1]     grad[kcur] = grad[S1].T @ qcur
+            //   qcur   = q[:D,iq1,iq2,iq3]      [D,1,1,1]     grad[qcur] = grad[S1]   @ kcur
+            //   vcur   = v[:M,:D,iq2,iq3]       [M,D,1,1]     grad[vcur] = grad[S5].T @ S4
+            //   S0     = -Inf                   [D,1,1,1]
+            //  ~S1[i]  = dot(kcur[:D,i], qcur)
+            //   S1     = qcur.T @ kcur          [M,1,1,1]     grad[S1]   = grad[S2] * scale
+            //   S2     = S1 * scale             [M,1,1,1]     grad[S2]   = diag_mask_zero(grad[S3], P)
+            //   S3     = diag_mask_inf(S2, P)   [M,1,1,1]     grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+            //   S4     = softmax(S3)            [M,1,1,1]     grad[S4]   = grad[S5] @ vcur
+            //  ~S5[i]  = dot(vcur[:,i], S4)
+            //   S5     = S4.T @ vcur            [D,1,1,1]     grad[S5]   = d[:D,iq1,iq2,iq3]
+            //  ~dst[i,iq1,iq2,iq3]  = S5[i]                          ^
+            //   dst[:D,iq1,iq2,iq3] = S5                             | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
+            //                                   dst   backward-------/ grad[dst]                 = d
+            //
+            // output gradients with their dependencies:
+            //
+            // grad[kcur] = grad[S1].T @ qcur
+            // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+            // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+            // grad[S4]   = grad[S5] @ vcur
+            // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+            // grad[qcur] = grad[S1]   @ kcur
+            // grad[vcur] = grad[S5].T @ S4
+            // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+            //
+            // in post-order:
+            //
+            // S1         = qcur.T @ kcur
+            // S2         = S1 * scale
+            // S3         = diag_mask_inf(S2, P)
+            // S4         = softmax(S3)
+            // grad[S4]   = d[:D,iq1,iq2,iq3] @ vcur
+            // grad[S3]   = S4 * (grad[S4] - dot(S4, grad[S4]))
+            // grad[S1]   = diag_mask_zero(grad[S3], P) * scale
+            // grad[qcur] = grad[S1] @ kcur
+            // grad[kcur] = grad[S1].T @ qcur
+            // grad[vcur] = d[:D,iq1,iq2,iq3].T @ S4
+            //
+            // using less variables (SM=S4):
+            //
+            // S             = diag_mask_inf(qcur.T @ kcur * scale, P)
+            // SM            = softmax(S)
+            // S             = d[:D,iq1,iq2,iq3] @ vcur
+            // dot_SM_gradSM = dot(SM, S)
+            // S             = SM * (S - dot(SM, S))
+            // S             = diag_mask_zero(S, P) * scale
+            //
+            // grad[q][:D,iq1,iq2,iq3] += S   @ kcur
+            // grad[k][:D,:M,iq2,iq3]  += S.T @ qcur
+            // grad[v][:M,:D,iq2,iq3]  += d[:D,iq1,iq2,iq3].T @ SM
+        }
+
+        // S = gradSM = d[:D,iq1,iq2,iq3] @ vcur
+        // S = d[:D,iq1,iq2,iq3] @ vcur
+        // S[:M] += vcur[:,ic] * d[ic,iq1,iq2,iq3]
+        ggml_vec_set_f32(D, S, 0);
+        for (int64_t ic = 0; ic < D; ++ic) {
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            ggml_vec_mad_f32(M,
+                    S,
+                     (float *) ((char *) v->data + (          ic*nbv1 + i2*nbv2 + i3*nbv3)),
+                    *(float *) ((char *) d->data + (ic*nbd1 + i1*nbd2 + i2*nbd2 + i3*nbd3)));
+        }
+
+        // S = SM * (S - dot(SM, S))
+        float dot_SM_gradSM = 0;
+        ggml_vec_dot_f32 (M, &dot_SM_gradSM, SM, S);
+        ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
+        ggml_vec_mul_f32 (M, S, S, SM);
+
+        // S = diag_mask_zero(S, P) * scale
+        if (masked) {
+            for (int64_t i = P + iq1 + 1; i < M; i++) {
+                S[i] = 0;
+            }
+        }
+        ggml_vec_scale_f32(M, S, scale);
+
+        void * grad_q = (char *) dst->data;
+        void * grad_k = (char *) dst->data + nb0*D*N*neq2*neq3;
+        void * grad_v = (char *) dst->data + nb0*D*N*neq2*neq3 + nb0*D*M*neq2*neq3;
+
+        const size_t nbgq1 = nb0*neq0;
+        const size_t nbgq2 = nb0*neq0*neq1;
+        const size_t nbgq3 = nb0*neq0*neq1*neq2;
+
+        const size_t nbgk1 = nb0*nek0;
+        const size_t nbgk2 = nb0*nek0*nek1;
+        const size_t nbgk3 = nb0*nek0*nek1*neq2;
+
+        const size_t nbgv1 = nb0*nev0;
+        const size_t nbgv2 = nb0*nev0*nev1;
+        const size_t nbgv3 = nb0*nev0*nev1*neq2;
+
+        // S    shape [M,1]
+        // SM   shape [M,1]
+        // kcur shape [D,M]
+        // qcur shape [D,1]
+        // vcur shape [M,D]
+        //
+        // grad[q][:D,iq1,iq2,iq3] += S @ kcur
+        // grad[q][:D,iq1,iq2,iq3] += shape[M,1] @ shape[D,M]
+        // grad[q][ic,iq1,iq2,iq3] += dot(kcur[:,ic],S.T)
+        // grad[q][ic,iq1,iq2,iq3] += dot(k[:D,ic,iq2,iq3],S.T)
+        for (int64_t ic = 0; ic < M; ++ic) {
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            ggml_vec_dot_f32(D,
+                    (float *) ((char *) grad_q  + (ic*nb0 + i1*nbgq1 + i2*nbgq2 + i3*nbgq3)),
+                    (float *) ((char *) k->data + (         ic*nbk1  + i2*nbk2  + i3*nbk3)),
+                    S);
+        }
+
+        // grad[k][:D,:M,iq2,iq3] += S.T       @ qcur
+        // grad[k][:D,ic,iq2,iq3] += S.T[0,ic] * qcur[:D,0]
+        // grad[k][:D,ic,iq2,iq3] += S[ic]     * qcur[:D,0]
+        for (int64_t ic = 0; ic < M; ++ic) {
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            ggml_vec_set_f32(D,
+                    (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                    0);
+            ggml_vec_mad_f32(D,
+                    (float *) ((char *) grad_k  + (ic*nbgk1 + i2*nbgk2 + i3*nbgk3)),
+                    (float *) ((char *) q->data + (i1*nbk1  + i2*nbk2  + i3*nbk3)),
+                    S[ic]);
+        }
+
+        // grad[v][:M,:D,iq2,iq3] += d[:D,iq1,iq2,iq3].T       @ SM
+        // grad[v][:M,ic,iq2,iq3] += d[:D,iq1,iq2,iq3].T[0,ic] * SM[:M]
+        // grad[v][:M,ic,iq2,iq3] += d[ic,iq1,iq2,iq3]         * SM[:M]
+        for (int64_t ic = 0; ic < D; ++ic) {
+            // dst indices
+            const int i1 = iq1;
+            const int i2 = iq2;
+            const int i3 = iq3;
+
+            ggml_vec_set_f32(M,
+                    (float *) ((char *) grad_v   + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                    0);
+            ggml_vec_mad_f32(M,
+                    (float *) ((char *) grad_v   + (          ic*nbgv1 + i2*nbgv2 + i3*nbgv3)),
+                    SM,
+                    *(float *) ((char *) d->data + (ic*nbd1 + i1*nbd2 + i2*nbd2 + i3*nbd3)));
+        }
+    }
+}
+
+static void ggml_compute_forward_flash_attn_back(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * d,
+        const bool masked,
+        struct ggml_tensor * dst) {
+    switch (q->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_flash_attn_back_f32(params, q, k, v, d, masked, dst);
+            } break;
+        default:
+            {
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_map_unary
 
 static void ggml_compute_forward_map_unary_f32(
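
Aside (not part of the commit): the heart of the per-row backward above is the softmax gradient identity grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4])), implemented with ggml_vec_dot_f32 / ggml_vec_acc1_f32 / ggml_vec_mul_f32. A standalone sketch of that one step on a single row, with made-up values:

    // standalone sketch (not from the commit): per-row softmax backward
    // given SM = softmax(S3) and upstream gradient dSM, compute
    // dS3 = SM * (dSM - dot(SM, dSM))
    #include <stdio.h>

    static void softmax_backward(int n, float * dS3, const float * SM, const float * dSM) {
        float dot = 0.0f;
        for (int i = 0; i < n; ++i) {
            dot += SM[i]*dSM[i];
        }
        for (int i = 0; i < n; ++i) {
            dS3[i] = SM[i]*(dSM[i] - dot);
        }
    }

    int main(void) {
        float SM [4] = { 0.1f, 0.2f, 0.3f, 0.4f }; // softmax output, sums to 1
        float dSM[4] = { 1.0f, 0.0f, 0.0f, 0.0f }; // upstream gradient
        float dS3[4];
        softmax_backward(4, dS3, SM, dSM);
        for (int i = 0; i < 4; ++i) {
            printf("dS3[%d] = %f\n", i, dS3[i]);
        }
        return 0;
    }

In the kernel this step is fused with the causal diag_mask_zero and the 1/sqrt(D) scale before the result is scattered into grad[q], grad[k] and grad[v].
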
@ -13371,6 +13825,13 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_flash_ff(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], tensor->opt[2], tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                int32_t t = ggml_get_i32_1d(tensor->opt[2], 0);
+                GGML_ASSERT(t == 0 || t == 1);
+                bool masked = t != 0;
+                ggml_compute_forward_flash_attn_back(params, tensor->src0, tensor->src1, tensor->opt[0], tensor->opt[1], masked, tensor);
+            } break;
         case GGML_OP_MAP_UNARY:
             {
                 const ggml_unary_op_f32_t fun = *((ggml_unary_op_f32_t *)tensor->opt[0]->data);
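
Aside (not part of the commit): the masked flag reaches this dispatch as a one-element i32 tensor in opt[2], written by ggml_flash_attn_back with ggml_new_i32 and read back here with ggml_get_i32_1d. A rough sketch of that round trip, assuming a small scratch context is enough:

    // rough sketch (not from the commit): encode/decode of the masked flag
    #include <stdbool.h>
    #include <stdint.h>
    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 16*1024*1024, // assumed large enough for this toy example
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        bool masked = true;
        struct ggml_tensor * flag = ggml_new_i32(ctx, masked ? 1 : 0); // written at op-construction time

        int32_t t = ggml_get_i32_1d(flag, 0);                          // read back at compute time
        printf("masked = %d\n", t != 0);

        ggml_free(ctx);
        return 0;
    }
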
@ -14007,12 +14468,169 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             } break;
         case GGML_OP_FLASH_ATTN:
             {
-                GGML_ASSERT(false); // not supported
+                struct ggml_tensor * flash_grad = NULL;
+                if (src0->grad || src1->grad || tensor->opt[0]->grad) {
+                    int32_t t = ggml_get_i32_1d(tensor->opt[1], 0);
+                    GGML_ASSERT(t == 0 || t == 1);
+                    bool masked = t != 0;
+                    flash_grad =
+                        ggml_flash_attn_back(ctx,
+                            src0->grad,
+                            src1->grad,
+                            tensor->opt[0]->grad,
+                            tensor->grad,
+                            masked);
+                }
+
+                if (src0->grad) {
+                    struct ggml_tensor * grad_q = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = 0;
+                    switch(src0->n_dims) {
+                        case 2:
+                            {
+                                grad_q = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    nb0*src0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_q = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_q = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src0->ne[0],
+                                    src0->ne[1],
+                                    src0->ne[2],
+                                    src0->ne[3],
+                                    nb0*src0->ne[0],
+                                    nb0*src0->ne[0]*src0->ne[1],
+                                    nb0*src0->ne[0]*src0->ne[1]*src0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src0->grad = ggml_add_impl(ctx,
+                        src0->grad,
+                        grad_q,
+                        inplace);
+                }
+
+                if (src1->grad) {
+                    struct ggml_tensor * grad_k = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3];
+                    switch(src1->n_dims) {
+                        case 2:
+                            {
+                                grad_k = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    nb0*src1->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_k = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_k = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    src1->ne[0],
+                                    src1->ne[1],
+                                    src1->ne[2],
+                                    src1->ne[3],
+                                    nb0*src1->ne[0],
+                                    nb0*src1->ne[0]*src1->ne[1],
+                                    nb0*src1->ne[0]*src1->ne[1]*src1->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    src1->grad = ggml_add_impl(ctx,
+                        src1->grad,
+                        grad_k,
+                        inplace);
+                }
+
+                struct ggml_tensor * opt0 = tensor->opt[0];
+
+                if (opt0->grad) {
+                    struct ggml_tensor * grad_v = NULL;
+                    const size_t nb0    = flash_grad->nb[0];
+                    const size_t offset = nb0*src0->ne[0]*src0->ne[1]*src0->ne[2]*src0->ne[3]
+                                        + nb0*src1->ne[0]*src1->ne[1]*src1->ne[2]*src1->ne[3];
+                    switch(opt0->n_dims) {
+                        case 2:
+                            {
+                                grad_v = ggml_view_2d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    nb0*opt0->ne[0],
+                                    offset);
+                            } break;
+                        case 3:
+                            {
+                                grad_v = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    offset);
+                            } break;
+                        case 4:
+                            {
+                                grad_v = ggml_view_3d(ctx,
+                                    flash_grad,
+                                    opt0->ne[0],
+                                    opt0->ne[1],
+                                    opt0->ne[2],
+                                    opt0->ne[3],
+                                    nb0*opt0->ne[0],
+                                    nb0*opt0->ne[0]*opt0->ne[1],
+                                    nb0*opt0->ne[0]*opt0->ne[1]*opt0->ne[2],
+                                    offset);
+                            } break;
+                    }
+
+                    opt0->grad = ggml_add_impl(ctx,
+                        opt0->grad,
+                        grad_v,
+                        inplace);
+                }
             } break;
         case GGML_OP_FLASH_FF:
             {
                 GGML_ASSERT(false); // not supported
             } break;
+        case GGML_OP_FLASH_ATTN_BACK:
+            {
+                GGML_ASSERT(false); // not supported
+            } break;
         case GGML_OP_MAP_UNARY:
         case GGML_OP_MAP_BINARY:
             {
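
Aside (not part of the commit): each per-input gradient above is a contiguous view into flash_grad, so its strides are running products of that input's own dimensions and its byte offset is the total size of the inputs packed before it. A standalone sketch of those computations for a made-up 3-D case:

    // standalone sketch (not from the commit): stride/offset arithmetic behind the
    // ggml_view_* calls used to slice flash_grad into grad_q / grad_k / grad_v
    #include <stdint.h>
    #include <stdio.h>

    int main(void) {
        const size_t nb0 = sizeof(float);

        // made-up shapes: q = [D,N,ne2], k = [D,M,ne2]
        const int64_t q_ne[3] = { 64, 32, 4 };
        const int64_t k_ne[3] = { 64, 48, 4 };

        // grad[q] view: offset 0, strides are running products of q's dims
        const size_t q_nb1 = nb0*q_ne[0];
        const size_t q_nb2 = nb0*q_ne[0]*q_ne[1];

        // grad[k] view: starts right after all of grad[q]
        const size_t k_offset = nb0*q_ne[0]*q_ne[1]*q_ne[2];
        const size_t k_nb1 = nb0*k_ne[0];
        const size_t k_nb2 = nb0*k_ne[0]*k_ne[1];

        printf("grad_q: offset=0 nb1=%zu nb2=%zu\n", q_nb1, q_nb2);
        printf("grad_k: offset=%zu nb1=%zu nb2=%zu\n", k_offset, k_nb1, k_nb2);
        return 0;
    }
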
@ -14575,6 +15193,27 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                             cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2
                         }
 
+                        work_size = MAX(work_size, cur);
+                    } break;
+                case GGML_OP_FLASH_ATTN_BACK:
+                    {
+                        node->n_tasks = n_threads;
+
+                        size_t cur = 0;
+
+                        const int64_t    D = node->src0->ne[0];
+                        const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL);
+                        const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back
+                        if (node->src1->type == GGML_TYPE_F32) {
+                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                        }
+
+                        if (node->src1->type == GGML_TYPE_F16) {
+                            cur  = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1)
+                            cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2
+                        }
+
                         work_size = MAX(work_size, cur);
                     } break;
                 case GGML_OP_MAP_UNARY:
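
Aside (not part of the commit): the scratch estimate above reserves two float rows per task (one for S, one for SM), each of length MAX(D, M rounded up to the softmax unroll). A quick standalone version of the same arithmetic with made-up sizes:

    // standalone sketch (not from the commit) of the FLASH_ATTN_BACK work-buffer estimate
    #include <stdint.h>
    #include <stdio.h>

    #define MAX(a, b) ((a) > (b) ? (a) : (b))

    // round x up to a multiple of m (same idea as ggml_up)
    static int64_t up(int64_t x, int64_t m) {
        return ((x + m - 1)/m)*m;
    }

    int main(void) {
        const int64_t D = 64, M = 50; // head dim and key length, made up
        const int     unroll  = 4;    // stand-in for GGML_SOFT_MAX_UNROLL
        const int     n_tasks = 8;

        const int64_t ne11 = up(M, unroll);
        const int64_t mxDn = MAX(D, ne11) * 2; // *2: one row for S, one for SM

        size_t cur  = sizeof(float)*mxDn*n_tasks;
        cur        += sizeof(float)*mxDn*n_tasks;

        printf("work buffer: %zu bytes\n", cur);
        return 0;
    }
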
ggml.h (9 changed lines)

@ -318,6 +318,7 @@ extern "C" {
 
         GGML_OP_FLASH_ATTN,
         GGML_OP_FLASH_FF,
+        GGML_OP_FLASH_ATTN_BACK,
 
         GGML_OP_MAP_UNARY,
         GGML_OP_MAP_BINARY,
@ -952,6 +953,14 @@ extern "C" {
             struct ggml_tensor  * v,
             bool                  masked);
 
+    GGML_API struct ggml_tensor * ggml_flash_attn_back(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * d,
+            bool                  masked);
+
     GGML_API struct ggml_tensor * ggml_flash_ff(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
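
Aside (not part of the commit): a rough construction-only sketch of the new entry point, assuming an F32 context with enough memory. Shapes follow the comments in ggml_flash_attn_back (q and k are [D,*], v is stored transposed), d is the upstream gradient of the flash_attn output, and the returned tensor packs dq, dk and dv back to back. Evaluating it then goes through ggml_build_forward / ggml_graph_compute like any other op.

    // rough construction-only sketch (not from the commit); dimensions are made up
    #include <stdio.h>
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params ip = {
            /*.mem_size   =*/ 64*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(ip);

        const int64_t D = 64, N = 32, M = 32;

        struct ggml_tensor * q = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, N); // [D,N]
        struct ggml_tensor * k = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, M); // [D,M]
        struct ggml_tensor * v = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, M, D); // [M,D], stored transposed
        struct ggml_tensor * d = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, D, N); // upstream grad of flash_attn output

        // packed gradients: ne = [D, M+N+M, 1, 1]; dq, dk, dv live at increasing offsets
        struct ggml_tensor * grads = ggml_flash_attn_back(ctx, q, k, v, d, /*masked=*/ true);

        printf("packed gradient tensor: [%lld, %lld]\n",
               (long long) grads->ne[0], (long long) grads->ne[1]);

        ggml_free(ctx);
        return 0;
    }
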