add more training parameters:

--enable-restart           Only for Adam optimizer. Enable restarts of cos-decay
--disable-restart          Only for Adam optimizer. Disable restarts of cos-decay
--opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero.
--opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero.
--opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero.
--adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero.
--adam-min-alpha N         Adam minimum learning rate alpha, usually 0.1 * alpha
This commit is contained in:
xaedes 2023-07-02 21:33:47 +02:00
parent d0fbb7d328
commit c6a18e15c1
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -3333,10 +3333,12 @@ float cosine_decay(const int decay_steps, const float alpha, int step) {
return decay;
}
float cosine_decay_restart(int decay_steps, const float alpha, int step, float restart_step_mult) {
while (step > decay_steps) {
step -= decay_steps;
decay_steps = (int) restart_step_mult * decay_steps;
// Cosine decay schedule with optional restarts.
//
//   decay_steps        length of the first decay period, in steps
//   alpha              forwarded unchanged to cosine_decay
//   step               current step, counted from the start of the first period
//   restart_step_mult  factor applied to the period length after every restart
//   enable_restart     when false no restart is applied and the call is
//                      exactly cosine_decay(decay_steps, alpha, step)
//
// Returns cosine_decay() evaluated for the period that `step` falls into.
float cosine_decay_restart(int decay_steps, const float alpha, int step, float restart_step_mult, bool enable_restart) {
    if (enable_restart) {
        // A non-positive period can never consume `step`, so stop instead of
        // looping forever (reachable when restart_step_mult shrinks the period to 0).
        while (decay_steps > 0 && step > decay_steps) {
            step -= decay_steps;
            // Multiply in floating point and truncate the product. The previous
            // `(int) restart_step_mult * decay_steps` truncated the multiplier
            // first (1.1f -> 1), so the period never grew after a restart.
            decay_steps = (int) (restart_step_mult * decay_steps);
        }
    }
    return cosine_decay(decay_steps, alpha, step);
}
@ -3376,14 +3378,21 @@ struct train_params {
int cos_decay_steps;
float cos_decay_restart;
float cos_decay_alpha;
bool enable_restart;
int opt_past;
float opt_delta;
int opt_max_no_improvement;
int lbfgs_n_iter;
int adam_n_iter;
float adam_alpha;
float adam_min_alpha;
float adam_decay;
float adam_beta1;
float adam_beta2;
float adam_gclip;
float adam_eps_f;
int mem_model_gb;
int mem_compute_gb;
@ -3424,19 +3433,26 @@ struct train_params get_default_train_params() {
params.use_scratch = true;
params.use_checkpointing = true;
params.opt_past = 0;
params.opt_delta = 1e-5f;
params.opt_max_no_improvement = 0;
// only adam
params.warmup = 100;
params.cos_decay_steps = 1000;
params.cos_decay_restart = 1.1f;
params.cos_decay_alpha = 0.0f;
params.enable_restart = false;
params.lbfgs_n_iter = 16;
params.adam_n_iter = 16;
params.adam_alpha = 1e-3f;
params.adam_min_alpha = 1e-4f;
params.adam_decay = 1e-1f;
params.adam_beta1 = 0.9f;
params.adam_beta2 = 0.999f;
params.adam_gclip = 1.0f;
params.adam_eps_f = 0.0f;
params.mem_model_gb = 2;
params.mem_compute_gb = 24;
@ -3482,13 +3498,20 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
fprintf(stderr, " --cos-decay-steps N Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
fprintf(stderr, " --cos-decay-restart N Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
fprintf(stderr, " --cos-decay-alpha N Only for Adam optimizer. Cosine decay alpha (default %f)\n", params->cos_decay_alpha);
fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
fprintf(stderr, " --enable-restart N Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
fprintf(stderr, " --disable-restart N Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha, usually 0.1 * alpha (default %f)\n", params->adam_min_alpha);
fprintf(stderr, " --adam-decay N AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
fprintf(stderr, " --lbfgs-iter N Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
fprintf(stderr, " --mem-model N Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
@ -3659,12 +3682,34 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
break;
}
params->cos_decay_alpha = std::stof(argv[i]);
} else if (arg == "--lbfgs-iter") {
} else if (arg == "--enable-restart") {
params->enable_restart = true;
} else if (arg == "--disable-restart") {
params->enable_restart = false;
} else if (arg == "--opt-past") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->lbfgs_n_iter = std::stoi(argv[i]);
params->opt_past = std::stoi(argv[i]);
} else if (arg == "--opt-delta") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_delta = std::stof(argv[i]);
} else if (arg == "--opt-max-no-improvement") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->opt_max_no_improvement = std::stoi(argv[i]);
} else if (arg == "--adam-epsf") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_eps_f = std::stof(argv[i]);
} else if (arg == "--adam-iter") {
if (++i >= argc) {
invalid_param = true;
@ -3677,6 +3722,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
break;
}
params->adam_alpha = std::stof(argv[i]);
} else if (arg == "--adam-min-alpha") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->adam_min_alpha = std::stof(argv[i]);
} else if (arg == "--adam-decay") {
if (++i >= argc) {
invalid_param = true;
@ -3701,6 +3752,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
break;
}
params->adam_gclip = std::stof(argv[i]);
} else if (arg == "--lbfgs-iter") {
if (++i >= argc) {
invalid_param = true;
break;
}
params->lbfgs_n_iter = std::stoi(argv[i]);
} else if (arg == "--mem-model") {
if (++i >= argc) {
invalid_param = true;
@ -3846,21 +3903,28 @@ int main(int argc, char ** argv) {
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
opt_params_adam.print_forward_graph = false;
opt_params_adam.print_forward_graph = false;
opt_params_adam.print_backward_graph = false;
opt_params_adam.n_threads = params.n_threads;
opt_params_adam.adam.n_iter = params.adam_n_iter;
opt_params_adam.adam.sched = 1.0f;
opt_params_adam.adam.alpha = params.adam_alpha;
opt_params_adam.adam.decay = params.adam_decay;
opt_params_adam.adam.beta1 = params.adam_beta1;
opt_params_adam.adam.beta2 = params.adam_beta2;
opt_params_adam.adam.gclip = params.adam_gclip;
opt_params_adam.n_threads = params.n_threads;
opt_params_adam.past = params.opt_past;
opt_params_adam.delta = params.opt_delta;
opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
opt_params_adam.adam.n_iter = params.adam_n_iter;
opt_params_adam.adam.sched = 1.0f;
opt_params_adam.adam.alpha = params.adam_alpha;
opt_params_adam.adam.decay = params.adam_decay;
opt_params_adam.adam.beta1 = params.adam_beta1;
opt_params_adam.adam.beta2 = params.adam_beta2;
opt_params_adam.adam.gclip = params.adam_gclip;
opt_params_adam.adam.eps_f = params.adam_eps_f;
opt_params_lbfgs.print_forward_graph = false;
opt_params_lbfgs.print_forward_graph = false;
opt_params_lbfgs.print_backward_graph = false;
opt_params_lbfgs.n_threads = params.n_threads;
opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
opt_params_lbfgs.n_threads = params.n_threads;
opt_params_adam.past = params.opt_past;
opt_params_adam.delta = params.opt_delta;
opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
opt->ctx = model.ctx;
opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
@ -3996,7 +4060,11 @@ int main(int argc, char ** argv) {
params.cos_decay_steps,
params.cos_decay_alpha,
opt->iter - params.warmup,
params.cos_decay_restart);
params.cos_decay_restart,
params.enable_restart);
float min_sched = params.adam_min_alpha / params.adam_alpha;
opt->params.adam.sched = min_sched + opt->params.adam.sched * (1.0f - min_sched);
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);