diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index bde29c5b0..aaaf954be 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -3386,6 +3386,7 @@ struct train_params {
     float adam_alpha;
     float adam_min_alpha;
     float adam_decay;
+    int   adam_decay_min_ndim;
     float adam_beta1;
     float adam_beta2;
     float adam_gclip;
@@ -3446,6 +3447,7 @@ struct train_params get_default_train_params() {
     params.adam_alpha = 1e-3f;
     params.adam_min_alpha = 1e-4f;
     params.adam_decay = 1e-1f;
+    params.adam_decay_min_ndim = 2;
     params.adam_beta1 = 0.9f;
     params.adam_beta2 = 0.999f;
     params.adam_gclip = 1.0f;
@@ -3505,6 +3507,7 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
     fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha, usually 0.1 * alpha (default %f)\n", params->adam_min_alpha);
     fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
+    fprintf(stderr, " --adam-decay-min-ndim N Minimum number of tensor dimensions to apply AdamW weight decay. Weight decay is not applied to tensors with less n_dims. (default %d)\n", params->adam_decay_min_ndim);
     fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
     fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
     fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
@@ -3731,6 +3734,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->adam_decay = std::stof(argv[i]);
+        } else if (arg == "--adam-decay-min-ndim") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_decay_min_ndim = std::stoi(argv[i]);
         } else if (arg == "--adam-beta1") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -3965,6 +3974,7 @@ int main(int argc, char ** argv) {
     opt_params_adam.adam.sched = 1.0f;
     opt_params_adam.adam.alpha = params.adam_alpha;
     opt_params_adam.adam.decay = params.adam_decay;
+    opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim;
     opt_params_adam.adam.beta1 = params.adam_beta1;
     opt_params_adam.adam.beta2 = params.adam_beta2;
     opt_params_adam.adam.gclip = params.adam_gclip;
diff --git a/ggml.c b/ggml.c
index e0f91ed5a..2138cb8bc 100644
--- a/ggml.c
+++ b/ggml.c
@@ -17316,6 +17316,7 @@ static enum ggml_opt_result ggml_opt_adam(
     const float beta2 = params.adam.beta2;
     const float eps = params.adam.eps;
     const float gclip = params.adam.gclip;
+    const int decay_min_ndim = params.adam.decay_min_ndim;
 
     float * m = opt->adam.m->data; // first moment
     float * v = opt->adam.v->data; // second moment
@@ -17394,7 +17395,7 @@ static enum ggml_opt_result ggml_opt_adam(
         int64_t i = 0;
         for (int p = 0; p < np; ++p) {
             const int64_t ne = ggml_nelements(ps[p]);
-            const float p_decay = decay * sched;
+            const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0) * sched;
             for (int64_t j = 0; j < ne; ++j) {
                 float x = ggml_get_f32_1d(ps[p], j);
                 float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
@@ -17911,6 +17912,7 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                 .n_iter = 10000,
                 .sched = 1.000f,
                 .decay = 0.0f,
+                .decay_min_ndim = 2,
                 .alpha = 0.001f,
                 .beta1 = 0.9f,
                 .beta2 = 0.999f,
diff --git a/ggml.h b/ggml.h
index fadc343ee..3980c0050 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1506,6 +1506,7 @@ extern "C" {
 
             float sched; // schedule multiplier (fixed, decay or warmup)
             float decay; // weight decay for AdamW, use 0.0f to disable
+            int   decay_min_ndim; // minimum number of tensor dimension to apply weight decay
             float alpha; // learning rate
             float beta1;
             float beta2;
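
Note (not part of the patch): a minimal standalone sketch of how the new decay_min_ndim gate behaves. The tensor ranks and values below are hypothetical examples; the selection rule simply mirrors the ternary added to ggml_opt_adam above, so 1-D tensors (biases, norm weights) receive no weight decay while tensors with 2 or more dimensions keep the configured decay.

    /* sketch.c - illustration only, assuming the defaults introduced above
       (decay_min_ndim = 2, adam_decay = 1e-1f); not part of ggml's API. */
    #include <stdio.h>

    /* Mirrors: p_decay = ((n_dims >= decay_min_ndim) ? decay : 0.0) * sched */
    static float per_param_decay(int n_dims, int decay_min_ndim, float decay, float sched) {
        return ((n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
    }

    int main(void) {
        const float decay    = 1e-1f; /* adam_decay default */
        const float sched    = 1.0f;  /* schedule multiplier */
        const int   min_ndim = 2;     /* adam_decay_min_ndim default */

        /* 1-D tensor (e.g. a norm weight or bias): decay is skipped */
        printf("1-D tensor decay: %f\n", per_param_decay(1, min_ndim, decay, sched));
        /* 2-D tensor (e.g. an attention weight matrix): decay applies */
        printf("2-D tensor decay: %f\n", per_param_decay(2, min_ndim, decay, sched));
        return 0;
    }

With the default of 2, behavior matches the common AdamW practice of decaying only weight matrices; passing --adam-decay-min-ndim 1 on the command line would restore decay on every parameter tensor.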