implement AdamW in ggml_opt_adam by adding weight decay parameter (default 0.001f)

also add a schedule parameter (default 1.0f) that can be used to scale alpha and decay according to learning schedule.
setting the decay parameter to zero disables AdamW resulting in normal Adam optimizer.

since the difference between Adam and AdamW is minimal it is not implemented as another optimizer, but integrated into the existing Adam optimizer.
This commit is contained in:
xaedes 2023-05-20 14:54:40 +02:00
parent f4e9ce7998
commit ef17d99f65
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 11 additions and 2 deletions

11
ggml.c
View file

@ -14603,7 +14603,9 @@ static enum ggml_opt_result ggml_opt_adam(
}
// constants
const float alpha = params.adam.alpha;
const float sched = params.adam.sched;
const float decay = params.adam.decay * sched;
const float alpha = params.adam.alpha * sched;
const float beta1 = params.adam.beta1;
const float beta2 = params.adam.beta2;
const float eps = params.adam.eps;
@ -14673,7 +14675,11 @@ static enum ggml_opt_result ggml_opt_adam(
// m^hat = m_t / (1 - beta1^t)
// v^hat = v_t / (1 - beta2^t)
// x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
// x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
// x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
// x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
// x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
// x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
ggml_vec_cpy_f32 (nx, mh, m);
ggml_vec_cpy_f32 (nx, vh, v);
@ -14684,6 +14690,7 @@ static enum ggml_opt_result ggml_opt_adam(
ggml_vec_acc1_f32 (nx, vh, eps);
ggml_vec_div_f32 (nx, mh, mh, vh);
ggml_vec_scale_f32(nx, x, 1.0f - decay);
ggml_vec_sub_f32 (nx, x, x, mh);
// update the parameters

2
ggml.h
View file

@ -1055,6 +1055,8 @@ extern "C" {
struct {
int n_iter;
float sched; // schedule multiplier (fixed, decay or warmup)
float decay; // weight decay for AdamW, use 0.0f to disable
float alpha; // learning rate
float beta1;
float beta2;