improve softmax backward pass

go from quadratic runtime to linear runtime by simplifying the formulas
This commit is contained in:
xaedes 2023-05-14 17:55:02 +02:00
parent ec1aea09ec
commit 4339f8cf28
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

38
ggml.c
View file

@ -10571,27 +10571,25 @@ static void ggml_compute_forward_soft_max_back_f32(
// J = diag(y)-y.T*y
// dx = J * dy
// dxk = sum_i(Jki * dyi)
// dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
// dxk = sum_i(-yk*yi * dyi) + yk*dyk
// dxk = -yk * sum_i(yi * dyi) + yk*dyk
// dxk = -yk * dot(y, dy) + yk*dyk
// dxk = yk * (- dot(y, dy) + dyk)
// dxk = yk * (dyk - dot(y, dy))
//
// post-order:
// dot_y_dy := dot(y, dy)
// dx := dy
// dx := dx - dot_y_dy
// dx := dx * y
// quadratic runtime, linear memory
for (int k = 0; k < nc; k++) {
ggml_float sum = 0.0;
for (int i = 0; i < k; i++) {
float Jki = -y[k]*y[i];
sum += (ggml_float) Jki * dy[i];
}
float Jkk = y[k] - y[k]*y[k];
sum += (ggml_float) Jkk * dy[k];
for (int i = k+1; i < nc; i++) {
float Jki = -y[k]*y[i];
sum += (ggml_float) Jki * dy[i];
}
dx[k] = (float) sum;
}
// linear runtime, no additional memory
float dot_y_dy = 0;
ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
ggml_vec_cpy_f32 (nc, dx, dy);
ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
ggml_vec_mul_f32 (nc, dx, dx, y);
#ifndef NDEBUG
for (int i = 0; i < nc; ++i) {