Add a minimum number of tensor dimensions required to apply weight decay (default 2).
This allows weight decay to be skipped for bias parameters.
This commit is contained in:
parent
d7aa4d9576
commit
e6ff0728e0
3 changed files with 14 additions and 1 deletion
|
@ -3386,6 +3386,7 @@ struct train_params {
|
|||
float adam_alpha;
|
||||
float adam_min_alpha;
|
||||
float adam_decay;
|
||||
int adam_decay_min_ndim;
|
||||
float adam_beta1;
|
||||
float adam_beta2;
|
||||
float adam_gclip;
|
||||
|
@ -3446,6 +3447,7 @@ struct train_params get_default_train_params() {
|
|||
params.adam_alpha = 1e-3f;
|
||||
params.adam_min_alpha = 1e-4f;
|
||||
params.adam_decay = 1e-1f;
|
||||
params.adam_decay_min_ndim = 2;
|
||||
params.adam_beta1 = 0.9f;
|
||||
params.adam_beta2 = 0.999f;
|
||||
params.adam_gclip = 1.0f;
|
||||
|
@ -3505,6 +3507,7 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
|
|||
fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
|
||||
fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha, usually 0.1 * alpha (default %f)\n", params->adam_min_alpha);
|
||||
fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
|
||||
fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
|
||||
fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
|
||||
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
|
||||
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
|
||||
|
@ -3731,6 +3734,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
|
|||
break;
|
||||
}
|
||||
params->adam_decay = std::stof(argv[i]);
|
||||
} else if (arg == "--adam-decay-min-ndim") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
break;
|
||||
}
|
||||
params->adam_decay_min_ndim = std::stoi(argv[i]);
|
||||
} else if (arg == "--adam-beta1") {
|
||||
if (++i >= argc) {
|
||||
invalid_param = true;
|
||||
|
@ -3965,6 +3974,7 @@ int main(int argc, char ** argv) {
|
|||
opt_params_adam.adam.sched = 1.0f;
|
||||
opt_params_adam.adam.alpha = params.adam_alpha;
|
||||
opt_params_adam.adam.decay = params.adam_decay;
|
||||
opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim;
|
||||
opt_params_adam.adam.beta1 = params.adam_beta1;
|
||||
opt_params_adam.adam.beta2 = params.adam_beta2;
|
||||
opt_params_adam.adam.gclip = params.adam_gclip;
|
||||
|
|
4
ggml.c
4
ggml.c
|
@ -17316,6 +17316,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
const float beta2 = params.adam.beta2;
|
||||
const float eps = params.adam.eps;
|
||||
const float gclip = params.adam.gclip;
|
||||
const int decay_min_ndim = params.adam.decay_min_ndim;
|
||||
|
||||
float * m = opt->adam.m->data; // first moment
|
||||
float * v = opt->adam.v->data; // second moment
|
||||
|
@ -17394,7 +17395,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
int64_t i = 0;
|
||||
for (int p = 0; p < np; ++p) {
|
||||
const int64_t ne = ggml_nelements(ps[p]);
|
||||
const float p_decay = decay * sched;
|
||||
const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0) * sched;
|
||||
for (int64_t j = 0; j < ne; ++j) {
|
||||
float x = ggml_get_f32_1d(ps[p], j);
|
||||
float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
|
||||
|
@ -17911,6 +17912,7 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
|
|||
.n_iter = 10000,
|
||||
.sched = 1.000f,
|
||||
.decay = 0.0f,
|
||||
.decay_min_ndim = 2,
|
||||
.alpha = 0.001f,
|
||||
.beta1 = 0.9f,
|
||||
.beta2 = 0.999f,
|
||||
|
|
1
ggml.h
1
ggml.h
|
@ -1506,6 +1506,7 @@ extern "C" {
|
|||
|
||||
float sched; // schedule multiplier (fixed, decay or warmup)
|
||||
float decay; // weight decay for AdamW, use 0.0f to disable
|
||||
int decay_min_ndim; // minimum number of tensor dimension to apply weight decay
|
||||
float alpha; // learning rate
|
||||
float beta1;
|
||||
float beta2;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue