From 3164f9338109b139182def39921ad4131d57c1e1 Mon Sep 17 00:00:00 2001
From: xaedes
Date: Thu, 1 Jun 2023 19:41:55 +0200
Subject: [PATCH] fix formulas in comments

---
 ggml.c | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/ggml.c b/ggml.c
index b16cd07a9..b3ae7f2f9 100644
--- a/ggml.c
+++ b/ggml.c
@@ -13592,12 +13592,12 @@ static void ggml_compute_forward_flash_attn_back_f32(
 // vcur = v[:M,:D,iq2,iq3] [M,D,1,1] grad[vcur] = grad[S5].T @ S4
 // S0 = -Inf [D,1,1,1]
 // ~S1[i] = dot(kcur[:D,i], qcur)
- // S1 = qcur.T @ kcur [M,1,1,1] grad[S1] = grad[S2] * scale
+ // S1 = qcur @ kcur.T [M,1,1,1] grad[S1] = grad[S2] * scale
 // S2 = S1 * scale [M,1,1,1] grad[S2] = diag_mask_zero(grad[S3], P)
 // S3 = diag_mask_inf(S2, P) [M,1,1,1] grad[S3] = S4 * (grad[S4] - dot(S4, grad[S4]))
 // S4 = softmax(S3) [M,1,1,1] grad[S4] = grad[S5] @ vcur
- // ~S5[i] = dot(vcur[:,i],S4)
- // S5 = S4.T @ vcur [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
+ // ~S5[i] = dot(vcur[:,i], S4)
+ // S5 = S4 @ vcur.T [D,1,1,1] grad[S5] = d[:D,iq1,iq2,iq3]
 // ~dst[i,iq1,iq2,iq3] = S5[i] ^
 // dst[:D,iq1,iq2,iq3] = S5 | grad[dst[:D,iq1,iq2,iq3]] = d[:D,iq1,iq2,iq3]
 // dst backward-/ grad[dst] = d
@@ -13615,7 +13615,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
 //
 // in post-order:
 //
- // S1 = qcur.T @ kcur
+ // S1 = qcur @ kcur.T
 // S2 = S1 * scale
 // S3 = diag_mask_inf(S2, P)
 // S4 = softmax(S3)
@@ -13628,7 +13628,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
 //
 // using less variables (SM=S4):
 //
- // S = diag_mask_inf(qcur.T @ kcur * scale, P)
+ // S = diag_mask_inf(qcur @ kcur.T * scale, P)
 // SM = softmax(S)
 // S = d[:D,iq1,iq2,iq3] @ vcur
 // dot_SM_gradSM = dot(SM, S)