implement AdamW in ggml_opt_adam by adding a weight decay parameter (default 0.001f)

also add a schedule parameter (default 1.0f) that can be used to scale alpha and decay according to a learning schedule. Setting the decay parameter to zero disables AdamW, resulting in the normal Adam optimizer. Since the difference between Adam and AdamW is minimal, it is not implemented as a separate optimizer but is integrated into the existing Adam optimizer.
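For reference, the decoupled weight decay update implemented by this commit (restated from the derivation comments in the diff below, with s = sched and λ = decay) is:

\[
x_t = x_{t-1} - s\left(\frac{\alpha\,\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon} + \lambda\,x_{t-1}\right)
    = (1 - s\lambda)\,x_{t-1} - \frac{s\,\alpha\,\hat{m}_t}{\sqrt{\hat{v}_t} + \epsilon}
\]

With λ = 0 the decay term vanishes and the step is exactly the plain Adam update, which is why no separate optimizer is needed.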
This commit is contained in:
parent f4e9ce7998
commit ef17d99f65
2 changed files with 11 additions and 2 deletions
ggml.c (11 changes)
@@ -14603,7 +14603,9 @@ static enum ggml_opt_result ggml_opt_adam(
     }

     // constants
-    const float alpha = params.adam.alpha;
+    const float sched = params.adam.sched;
+    const float decay = params.adam.decay * sched;
+    const float alpha = params.adam.alpha * sched;
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
     const float eps   = params.adam.eps;
@@ -14673,7 +14675,11 @@ static enum ggml_opt_result ggml_opt_adam(

         // m^hat = m_t / (1 - beta1^t)
         // v^hat = v_t / (1 - beta2^t)
-        // x_t = x_t-1 - alpha*m^hat/(sqrt(v^hat) + eps)
+        // x_t = x_t-1 - sched*(alpha*m^hat/(sqrt(v^hat) + eps) + decay*x_t-1)
+        // x_t = x_t-1 - sched*alpha*m^hat/(sqrt(v^hat) + eps) - sched*decay*x_t-1
+        // x_t = x_t-1*(1-sched*decay) - sched*alpha*m^hat/(sqrt(v^hat) + eps)
+        // x_t = x_t-1*(1-sched*decay) + sched*decay*(-alpha/decay)*m^hat/(sqrt(v^hat) + eps)
+        // x_t = mix(x_t-1, (-alpha/decay)*m^hat/(sqrt(v^hat) + eps), sched*decay)
         ggml_vec_cpy_f32 (nx, mh, m);
         ggml_vec_cpy_f32 (nx, vh, v);

@@ -14684,6 +14690,7 @@ static enum ggml_opt_result ggml_opt_adam(
         ggml_vec_acc1_f32 (nx, vh, eps);

         ggml_vec_div_f32 (nx, mh, mh, vh);
+        ggml_vec_scale_f32(nx, x, 1.0f - decay);
         ggml_vec_sub_f32 (nx, x, x, mh);

         // update the parameters
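As a sanity check on the scale-then-subtract realization above, the following scalar sketch (illustration only, not part of the commit; the numeric values are made up) shows that it computes the same step as the textbook decoupled weight decay formula from the comments:

    #include <math.h>
    #include <stdio.h>

    int main(void) {
        const float sched = 1.0f;            /* schedule multiplier                  */
        const float decay = 0.001f * sched;  /* weight decay, pre-scaled by schedule */
        const float alpha = 0.001f * sched;  /* learning rate, pre-scaled            */
        const float eps   = 1e-8f;

        float x  = 0.5f;    /* one parameter value          */
        float mh = 0.02f;   /* bias-corrected first moment  */
        float vh = 0.004f;  /* bias-corrected second moment */

        /* textbook form: x_t = x_t-1 - (alpha*m^hat/(sqrt(v^hat)+eps) + decay*x_t-1) */
        float x_ref = x - (alpha*mh/(sqrtf(vh) + eps) + decay*x);

        /* form used in ggml_opt_adam: scale by (1 - decay), then subtract the step */
        float step  = alpha*mh/(sqrtf(vh) + eps);  /* mirrors ggml_vec_div_f32   */
        float x_new = x*(1.0f - decay);            /* mirrors ggml_vec_scale_f32 */
        x_new      -= step;                        /* mirrors ggml_vec_sub_f32   */

        /* the two results agree up to floating point rounding */
        printf("reference: %.9f\ncode path: %.9f\n", x_ref, x_new);
        return 0;
    }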
ggml.h (2 changes)
@@ -1055,6 +1055,8 @@ extern "C" {
        struct {
            int n_iter;

+            float sched; // schedule multiplier (fixed, decay or warmup)
+            float decay; // weight decay for AdamW, use 0.0f to disable
            float alpha; // learning rate
            float beta1;
            float beta2;
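A minimal usage sketch of the new fields, assuming the surrounding optimizer API (ggml_opt_default_params and ggml_opt as declared in ggml.h) is otherwise unchanged; here ctx stands for an initialized ggml_context and f for the scalar loss tensor:

    struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);
    params.adam.sched = 1.0f;    /* constant schedule: full alpha and decay   */
    params.adam.decay = 0.001f;  /* AdamW weight decay; 0.0f gives plain Adam */
    enum ggml_opt_result result = ggml_opt(ctx, params, f);

Note that because decay is multiplied by sched inside ggml_opt_adam, a warmup or decay schedule applied through sched scales the learning rate and the weight decay together.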