From 4339f8cf285c95af503a7105b747b29ef7b1d64b Mon Sep 17 00:00:00 2001 From: xaedes Date: Sun, 14 May 2023 17:55:02 +0200 Subject: [PATCH] improve softmax backward pass go from quadratic runtime to linear runtime by simplifying the formulas --- ggml.c | 38 ++++++++++++++++++-------------------- 1 file changed, 18 insertions(+), 20 deletions(-) diff --git a/ggml.c b/ggml.c index 2cc51fcc0..0935549a4 100644 --- a/ggml.c +++ b/ggml.c @@ -10571,27 +10571,25 @@ static void ggml_compute_forward_soft_max_back_f32( // J = diag(y)-y.T*y // dx = J * dy // dxk = sum_i(Jki * dyi) + // dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk + // dxk = sum_i(-yk*yi * dyi) + yk*dyk + // dxk = -yk * sum_i(yi * dyi) + yk*dyk + // dxk = -yk * dot(y, dy) + yk*dyk + // dxk = yk * (- dot(y, dy) + dyk) + // dxk = yk * (dyk - dot(y, dy)) + // + // post-order: + // dot_y_dy := dot(y, dy) + // dx := dy + // dx := dx - dot_y_dy + // dx := dx * y - // quadratic runtime, linear memory - for (int k = 0; k < nc; k++) { - - ggml_float sum = 0.0; - - for (int i = 0; i < k; i++) { - float Jki = -y[k]*y[i]; - sum += (ggml_float) Jki * dy[i]; - } - - float Jkk = y[k] - y[k]*y[k]; - sum += (ggml_float) Jkk * dy[k]; - - for (int i = k+1; i < nc; i++) { - float Jki = -y[k]*y[i]; - sum += (ggml_float) Jki * dy[i]; - } - - dx[k] = (float) sum; - } + // linear runtime, no additional memory + float dot_y_dy = 0; + ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy); + ggml_vec_cpy_f32 (nc, dx, dy); + ggml_vec_acc1_f32(nc, dx, -dot_y_dy); + ggml_vec_mul_f32 (nc, dx, dx, y); #ifndef NDEBUG for (int i = 0; i < nc; ++i) {