remove commented-out vectorized code of opt_adam

the vectorized code might be a bit faster for a low number of parameters, but it had a big memory usage overhead.
This commit is contained in:
xaedes 2023-07-03 18:56:05 +02:00
parent 0f6a8ab519
commit 87035b96f7
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

38
ggml.c
View file

@@ -17417,44 +17417,6 @@ static enum ggml_opt_result ggml_opt_adam(
}
}
}
{
// // update the gradient
// ggml_opt_get_grad(np, ps, g1);
// // m_t = beta1*m_t-1 + (1 - beta1)*g_t
// ggml_vec_scale_f32(nx, m, beta1);
// ggml_vec_mad_f32 (nx, m, g1, 1.0f - beta1);
// // g2 = g1^2
// ggml_vec_sqr_f32 (nx, g2, g1);
// // v_t = beta2*v_t-1 + (1 - beta2)*g_t^2
// ggml_vec_scale_f32(nx, v, beta2);
// ggml_vec_mad_f32 (nx, v, g2, 1.0f - beta2);
// // m^hat = m_t / (1 - beta1^t)
// // v^hat = v_t / (1 - beta2^t)
// // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
// // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
// // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
// // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
// // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
// ggml_vec_cpy_f32 (nx, mh, m);
// ggml_vec_cpy_f32 (nx, vh, v);
// ggml_vec_scale_f32(nx, mh, alpha/(1.0f - powf(beta1, opt->iter)));
// ggml_vec_scale_f32(nx, vh, 1.0f/(1.0f - powf(beta2, opt->iter)));
// ggml_vec_sqrt_f32 (nx, vh, vh);
// ggml_vec_acc1_f32 (nx, vh, eps);
// ggml_vec_div_f32 (nx, mh, mh, vh);
// ggml_vec_scale_f32(nx, x, 1.0f - decay);
// ggml_vec_sub_f32 (nx, x, x, mh);
// // update the parameters
// ggml_opt_set_params(np, ps, x);
}
if (callback) {
callback(callback_data, &sched);