fix formulas in comments

This commit is contained in:
xaedes 2023-06-01 19:41:55 +02:00
parent 0e269665cd
commit 3164f93381
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

10
ggml.c
View file

@@ -13592,12 +13592,12 @@ static void ggml_compute_forward_flash_attn_back_f32(
// vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
// S0 = -Inf [D,1,1,1]
// ~S1[i] = dot(kcur[:D,i], qcur)
-// S1 = qcur.T @ kcur [M,1,1,1] grad[S1] = grad[S2] * scale
+// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
-// ~S5[i] = dot(vcur[:,i],S4)
-// S5 = S4.T @ vcur [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
+// ~S5[i] = dot(vcur[:,i], S4)
+// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
// dst backward-/ grad[dst] = d
@@ -13615,7 +13615,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
//
// in post-order:
//
-// S1 = qcur.T @ kcur
+// S1 = qcur @ kcur.T
// S2 = S1 * scale
// S3 = diag_mask_inf(S2, P)
// S4 = softmax(S3)
@@ -13628,7 +13628,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
//
// using less variables (SM=S4):
//
-// S = diag_mask_inf(qcur.T @ kcur * scale, P)
+// S = diag_mask_inf(qcur @ kcur.T * scale, P)
// SM = softmax(S)
// S = d[:D,iq1,iq2,iq3] @ vcur
// dot_SM_gradSM = dot(SM, S)