improve softmax backward pass
go from quadratic runtime to linear runtime by simplifying the formulas
This commit is contained in:
parent
ec1aea09ec
commit
4339f8cf28
1 changed file with 18 additions and 20 deletions
38
ggml.c
38
ggml.c
|
@@ -10571,27 +10571,25 @@ static void ggml_compute_forward_soft_max_back_f32(
|
|||
// J = diag(y)-y.T*y
|
||||
// dx = J * dy
|
||||
// dxk = sum_i(Jki * dyi)
|
||||
// dxk = sum_i(-yk*yi * dyi) - (-yk*yk)*dyk + (yk - yk*yk)*dyk
|
||||
// dxk = sum_i(-yk*yi * dyi) + yk*dyk
|
||||
// dxk = -yk * sum_i(yi * dyi) + yk*dyk
|
||||
// dxk = -yk * dot(y, dy) + yk*dyk
|
||||
// dxk = yk * (- dot(y, dy) + dyk)
|
||||
// dxk = yk * (dyk - dot(y, dy))
|
||||
//
|
||||
// post-order:
|
||||
// dot_y_dy := dot(y, dy)
|
||||
// dx := dy
|
||||
// dx := dx - dot_y_dy
|
||||
// dx := dx * y
|
||||
|
||||
// quadratic runtime, linear memory
|
||||
for (int k = 0; k < nc; k++) {
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
|
||||
for (int i = 0; i < k; i++) {
|
||||
float Jki = -y[k]*y[i];
|
||||
sum += (ggml_float) Jki * dy[i];
|
||||
}
|
||||
|
||||
float Jkk = y[k] - y[k]*y[k];
|
||||
sum += (ggml_float) Jkk * dy[k];
|
||||
|
||||
for (int i = k+1; i < nc; i++) {
|
||||
float Jki = -y[k]*y[i];
|
||||
sum += (ggml_float) Jki * dy[i];
|
||||
}
|
||||
|
||||
dx[k] = (float) sum;
|
||||
}
|
||||
// linear runtime, no additional memory
|
||||
float dot_y_dy = 0;
|
||||
ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
|
||||
ggml_vec_cpy_f32 (nc, dx, dy);
|
||||
ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
|
||||
ggml_vec_mul_f32 (nc, dx, dx, y);
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue