From ef17d99f657aef57b977ce33fc13345f467b1f44 Mon Sep 17 00:00:00 2001 From: xaedes Date: Sat, 20 May 2023 14:54:40 +0200 Subject: [PATCH] implement AdamW in ggml_opt_adam by adding weight decay parameter (default 0.001f) also add a schedule parameter (default 1.0f) that can be used to scale alpha and decay according to learning schedule. setting the decay parameter to zero disables AdamW resulting in normal Adam optimizer. since the difference between Adam and AdamW is minimal it is not implemented as another optimizer, but integrated into the existing Adam optimizer. --- ggml.c | 11 +++++++++-- ggml.h | 2 ++ 2 files changed, 11 insertions(+), 2 deletions(-) diff --git a/ggml.c b/ggml.c index 7b66846ee..f259f5605 100644 --- a/ggml.c +++ b/ggml.c @@ -14603,7 +14603,9 @@ static enum ggml_opt_result ggml_opt_adam( } // constants - const float alpha = params.adam.alpha; + const float sched = params.adam.sched; + const float decay = params.adam.decay * sched; + const float alpha = params.adam.alpha * sched; const float beta1 = params.adam.beta1; const float beta2 = params.adam.beta2; const float eps = params.adam.eps; @@ -14673,7 +14675,11 @@ static enum ggml_opt_result ggml_opt_adam( // m^hat = m_t / (1 - beta1^t) // v^hat = v_t / (1 - beta2^t) - // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps) + // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1) + // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1 + // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps) + // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps) + // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay) ggml_vec_cpy_f32 (nx, mh, m); ggml_vec_cpy_f32 (nx, vh, v); @@ -14684,6 +14690,7 @@ static enum ggml_opt_result ggml_opt_adam( ggml_vec_acc1_f32 (nx, vh, eps); ggml_vec_div_f32 (nx, mh, mh, vh); + ggml_vec_scale_f32(nx, x, 1.0f - decay); ggml_vec_sub_f32 (nx, x, x, mh); // update the parameters diff --git a/ggml.h b/ggml.h index 62da0bd35..6ce660c74 100644 --- a/ggml.h +++ b/ggml.h @@ -1055,6 +1055,8 @@ extern "C" { struct { int n_iter; + float sched; // schedule multiplier (fixed, decay or warmup) + float decay; // weight decay for AdamW, use 0.0f to disable float alpha; // learning rate float beta1; float beta2;