fix formulas in comments
This commit is contained in:
parent
0e269665cd
commit
3164f93381
1 changed files with 5 additions and 5 deletions
10
ggml.c
10
ggml.c
|
@ -13592,12 +13592,12 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
// vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
|
// vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
|
||||||
// S0 = -Inf [D,1,1,1]
|
// S0 = -Inf [D,1,1,1]
|
||||||
// ~S1[i] = dot(kcur[:D,i], qcur)
|
// ~S1[i] = dot(kcur[:D,i], qcur)
|
||||||
// S1 = qcur.T @ kcur [M,1,1,1] grad[S1] = grad[S2] * scale
|
// S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
|
||||||
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
|
// S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
|
||||||
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
// S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
|
||||||
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
|
// S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
|
||||||
// ~S5[i] = dot(vcur[:,i],S4)
|
// ~S5[i] = dot(vcur[:,i], S4)
|
||||||
// S5 = S4.T @ vcur [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
|
// S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
|
||||||
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
|
// ~dst[i,iq1,iq2,iq3] = S5[i] ^
|
||||||
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
|
// dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
|
||||||
// dst backward-/ grad[dst] = d
|
// dst backward-/ grad[dst] = d
|
||||||
|
@ -13615,7 +13615,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
//
|
//
|
||||||
// in post-order:
|
// in post-order:
|
||||||
//
|
//
|
||||||
// S1 = qcur.T @ kcur
|
// S1 = qcur @ kcur.T
|
||||||
// S2 = S1 * scale
|
// S2 = S1 * scale
|
||||||
// S3 = diag_mask_inf(S2, P)
|
// S3 = diag_mask_inf(S2, P)
|
||||||
// S4 = softmax(S3)
|
// S4 = softmax(S3)
|
||||||
|
@ -13628,7 +13628,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
||||||
//
|
//
|
||||||
// using less variables (SM=S4):
|
// using less variables (SM=S4):
|
||||||
//
|
//
|
||||||
// S = diag_mask_inf(qcur.T @ kcur * scale, P)
|
// S = diag_mask_inf(qcur @ kcur.T * scale, P)
|
||||||
// SM = softmax(S)
|
// SM = softmax(S)
|
||||||
// S = d[:D,iq1,iq2,iq3] @ vcur
|
// S = d[:D,iq1,iq2,iq3] @ vcur
|
||||||
// dot_SM_gradSM = dot(SM, S)
|
// dot_SM_gradSM = dot(SM, S)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue