add more training parameters:

--enable-restart N         Only for Adam optimizer. Enable restarts of cos-decay
--disable-restart N        Only for Adam optimizer. Disable restarts of cos-decay
--opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero.
--opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero.
--opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero.
--adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero.
--adam-min-alpha N         Adam minimum learning rate alpha, usually 0.1 * alpha
parent d0fbb7d328
commit c6a18e15c1
1 changed file with 88 additions and 20 deletions
@@ -3333,10 +3333,12 @@ float cosine_decay(const int decay_steps, const float alpha, int step) {
     return decay;
 }
 
-float cosine_decay_restart(int decay_steps, const float alpha, int step, float restart_step_mult) {
-    while (step > decay_steps) {
-        step -= decay_steps;
-        decay_steps = (int) restart_step_mult * decay_steps;
+float cosine_decay_restart(int decay_steps, const float alpha, int step, float restart_step_mult, bool enable_restart) {
+    if (enable_restart) {
+        while (step > decay_steps) {
+            step -= decay_steps;
+            decay_steps = (int) restart_step_mult * decay_steps;
+        }
     }
     return cosine_decay(decay_steps, alpha, step);
 }
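
One detail worth flagging in the hunk above: in C++ the cast binds tighter than multiplication, so "decay_steps = (int) restart_step_mult * decay_steps;" truncates restart_step_mult (default 1.1f, set further down) to 1 before multiplying, which leaves decay_steps unchanged on every restart. If SGDR-style growing restart periods were intended, a minimal fix would be the parenthesized form below (my sketch, not part of this commit):

    // grow the decay period on each restart: with restart_step_mult = 1.1f
    // and decay_steps = 1000, successive periods become 1000, 1100, 1210, ...
    decay_steps = (int) (restart_step_mult * decay_steps);
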
@@ -3376,14 +3378,21 @@ struct train_params {
     int   cos_decay_steps;
     float cos_decay_restart;
     float cos_decay_alpha;
+    bool  enable_restart;
+
+    int   opt_past;
+    float opt_delta;
+    int   opt_max_no_improvement;
 
     int   lbfgs_n_iter;
     int   adam_n_iter;
     float adam_alpha;
+    float adam_min_alpha;
     float adam_decay;
     float adam_beta1;
     float adam_beta2;
     float adam_gclip;
+    float adam_eps_f;
 
     int mem_model_gb;
     int mem_compute_gb;
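
The three opt_* fields feed ggml's generic convergence checks (the past, delta, and max_no_improvement fields of ggml_opt_params, wired up in main below). Roughly, the optimizer keeps a ring buffer of the last opt_past loss values and stops early when the relative change over that window falls below opt_delta. A simplified sketch of a test of this shape (an illustration of the idea, not ggml's actual code):

    #include <cmath>

    // pf: ring buffer of the last `past` loss values; fx: current loss.
    // past <= 0 disables the test, matching "Disabled when zero."
    static bool delta_converged(float * pf, int past, float delta, int iter, float fx) {
        if (past <= 0) {
            return false;
        }
        const float rate = (pf[iter % past] - fx) / fx; // relative improvement
        pf[iter % past] = fx;                           // remember current loss
        return std::fabs(rate) < delta;
    }

max_no_improvement is the complementary counter-based test: count successive iterations whose loss does not beat the best seen so far, and stop once the count reaches the limit.
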
@@ -3424,19 +3433,26 @@ struct train_params get_default_train_params() {
     params.use_scratch       = true;
     params.use_checkpointing = true;
 
+    params.opt_past               = 0;
+    params.opt_delta              = 1e-5f;
+    params.opt_max_no_improvement = 0;
+
     // only adam
     params.warmup            = 100;
     params.cos_decay_steps   = 1000;
     params.cos_decay_restart = 1.1f;
     params.cos_decay_alpha   = 0.0f;
+    params.enable_restart    = false;
 
     params.lbfgs_n_iter   = 16;
     params.adam_n_iter    = 16;
     params.adam_alpha     = 1e-3f;
+    params.adam_min_alpha = 1e-4f;
     params.adam_decay     = 1e-1f;
     params.adam_beta1     = 0.9f;
     params.adam_beta2     = 0.999f;
     params.adam_gclip     = 1.0f;
+    params.adam_eps_f     = 0.0f;
 
     params.mem_model_gb   = 2;
     params.mem_compute_gb = 24;
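
All three new stopping tests ship disabled by default (opt_past = 0, opt_max_no_improvement = 0, adam_eps_f = 0.0f), and enable_restart defaults to false, so a run with no new flags behaves exactly as before this commit. For orientation on the schedule numbers: with warmup = 100 and cos_decay_steps = 1000, cosine decay starts after the 100 warmup steps; with --enable-restart the decay restarts every 1000 steps thereafter (and, given the truncation noted under the first hunk, the period stays at 1000 rather than growing by the 1.1x cos_decay_restart factor).
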
@@ -3482,13 +3498,20 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * params) {
     fprintf(stderr, "  --cos-decay-steps N        Only for Adam optimizer. Number of cosine decay steps (default %d)\n", params->cos_decay_steps);
     fprintf(stderr, "  --cos-decay-restart N      Only for Adam optimizer. Increase of cosine decay steps after restart (default %f)\n", params->cos_decay_restart);
     fprintf(stderr, "  --cos-decay-alpha N        Only for Adam optimizer. Cosine decay alpha (default %f)\n", params->cos_decay_alpha);
-    fprintf(stderr, "  --lbfgs-iter N             Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
+    fprintf(stderr, "  --enable-restart N         Only for Adam optimizer. Enable restarts of cos-decay %s\n", params->enable_restart ? "(default)" : "");
+    fprintf(stderr, "  --disable-restart N        Only for Adam optimizer. Disable restarts of cos-decay %s\n", !params->enable_restart ? "(default)" : "");
+    fprintf(stderr, "  --opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
+    fprintf(stderr, "  --opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
+    fprintf(stderr, "  --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
+    fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
     fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
     fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
+    fprintf(stderr, "  --adam-min-alpha N         Adam minimum learning rate alpha, usually 0.1 * alpha (default %f)\n", params->adam_min_alpha);
     fprintf(stderr, "  --adam-decay N             AdamW weight decay. Values greater zero enable AdamW instead of regular Adam. (default %f)\n", params->adam_decay);
     fprintf(stderr, "  --adam-beta1 N             AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
     fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
     fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
+    fprintf(stderr, "  --lbfgs-iter N             Maximum number of LBFGS optimization iterations for each batch (default %d)\n", params->lbfgs_n_iter);
     fprintf(stderr, "  --mem-model N              Memory to allocate for model and cache in gigabytes. (default %d)\n", params->mem_model_gb);
     fprintf(stderr, "  --mem-compute N            Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
     fprintf(stderr, "  --mem-compute0 N           Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
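
Two remarks on the usage text: --enable-restart and --disable-restart are parsed as plain switches (see the parsing hunk below), so the "N" in their usage lines is a leftover from the value-taking options; and since the new tests default to off, passing the flags is the only way to activate them. A hypothetical invocation exercising the new options (the binary name and the elided required arguments are placeholders):

    ./train-text-from-scratch ... \
        --enable-restart --cos-decay-steps 1000 --cos-decay-restart 1.1 \
        --opt-past 64 --opt-delta 1e-5 --opt-max-no-improvement 16 \
        --adam-min-alpha 1e-4 --adam-epsf 1e-6
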
@@ -3659,12 +3682,34 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->cos_decay_alpha = std::stof(argv[i]);
-        } else if (arg == "--lbfgs-iter") {
+        } else if (arg == "--enable-restart") {
+            params->enable_restart = true;
+        } else if (arg == "--disable-restart") {
+            params->enable_restart = false;
+        } else if (arg == "--opt-past") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
-            params->lbfgs_n_iter = std::stoi(argv[i]);
+            params->opt_past = std::stoi(argv[i]);
+        } else if (arg == "--opt-delta") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->opt_delta = std::stof(argv[i]);
+        } else if (arg == "--opt-max-no-improvement") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->opt_max_no_improvement = std::stoi(argv[i]);
+        } else if (arg == "--adam-epsf") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_eps_f = std::stof(argv[i]);
         } else if (arg == "--adam-iter") {
             if (++i >= argc) {
                 invalid_param = true;
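
Every value-taking option repeats the same "if (++i >= argc) { invalid_param = true; break; }" guard before calling std::stof or std::stoi. A hypothetical helper (my sketch, not in this commit) that would collapse the pattern; note that std::stof and std::stoi throw on malformed input, which neither the pattern above nor this sketch catches:

    #include <string>

    // fetch the next argv entry and parse it as a float; false means the
    // caller should set invalid_param and bail out
    static bool next_float_arg(int argc, char ** argv, int & i, float & out) {
        if (++i >= argc) {
            return false;
        }
        out = std::stof(argv[i]); // may throw std::invalid_argument / std::out_of_range
        return true;
    }
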
@@ -3677,6 +3722,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->adam_alpha = std::stof(argv[i]);
+        } else if (arg == "--adam-min-alpha") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->adam_min_alpha = std::stof(argv[i]);
         } else if (arg == "--adam-decay") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -3701,6 +3752,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->adam_gclip = std::stof(argv[i]);
+        } else if (arg == "--lbfgs-iter") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->lbfgs_n_iter = std::stoi(argv[i]);
         } else if (arg == "--mem-model") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -3846,21 +3903,28 @@ int main(int argc, char ** argv) {
 
     struct ggml_opt_params opt_params_adam  = ggml_opt_default_params(GGML_OPT_ADAM);
     struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
-    opt_params_adam.print_forward_graph = false;
-    opt_params_adam.print_backward_graph = false;
-    opt_params_adam.n_threads = params.n_threads;
-    opt_params_adam.adam.n_iter = params.adam_n_iter;
-    opt_params_adam.adam.sched  = 1.0f;
-    opt_params_adam.adam.alpha  = params.adam_alpha;
-    opt_params_adam.adam.decay  = params.adam_decay;
-    opt_params_adam.adam.beta1  = params.adam_beta1;
-    opt_params_adam.adam.beta2  = params.adam_beta2;
-    opt_params_adam.adam.gclip  = params.adam_gclip;
+    opt_params_adam.print_forward_graph  = false;
+    opt_params_adam.print_backward_graph = false;
+    opt_params_adam.n_threads          = params.n_threads;
+    opt_params_adam.past               = params.opt_past;
+    opt_params_adam.delta              = params.opt_delta;
+    opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
+    opt_params_adam.adam.n_iter        = params.adam_n_iter;
+    opt_params_adam.adam.sched         = 1.0f;
+    opt_params_adam.adam.alpha         = params.adam_alpha;
+    opt_params_adam.adam.decay         = params.adam_decay;
+    opt_params_adam.adam.beta1         = params.adam_beta1;
+    opt_params_adam.adam.beta2         = params.adam_beta2;
+    opt_params_adam.adam.gclip         = params.adam_gclip;
+    opt_params_adam.adam.eps_f         = params.adam_eps_f;
 
-    opt_params_lbfgs.print_forward_graph = false;
-    opt_params_lbfgs.print_backward_graph = false;
-    opt_params_lbfgs.n_threads = params.n_threads;
-    opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
+    opt_params_lbfgs.print_forward_graph  = false;
+    opt_params_lbfgs.print_backward_graph = false;
+    opt_params_lbfgs.n_threads         = params.n_threads;
+    opt_params_adam.past               = params.opt_past;
+    opt_params_adam.delta              = params.opt_delta;
+    opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
+    opt_params_lbfgs.lbfgs.n_iter      = params.lbfgs_n_iter;
 
     opt->ctx = model.ctx;
     opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
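
Note that in the LBFGS block of this hunk the three convergence settings are still assigned to opt_params_adam rather than opt_params_lbfgs, so the LBFGS path never actually receives past/delta/max_no_improvement (the Adam struct merely gets them set a second time). Presumably the intent was (my correction, not part of this commit):

    // assign the convergence settings to the LBFGS parameters instead
    opt_params_lbfgs.past               = params.opt_past;
    opt_params_lbfgs.delta              = params.opt_delta;
    opt_params_lbfgs.max_no_improvement = params.opt_max_no_improvement;
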
@@ -3996,7 +4060,11 @@ int main(int argc, char ** argv) {
                 params.cos_decay_steps,
                 params.cos_decay_alpha,
                 opt->iter - params.warmup,
-                params.cos_decay_restart);
+                params.cos_decay_restart,
+                params.enable_restart);
+
+            float min_sched = params.adam_min_alpha / params.adam_alpha;
+            opt->params.adam.sched = min_sched + opt->params.adam.sched * (1.0f - min_sched);
 
             printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
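With the defaults set earlier (adam_alpha = 1e-3f, adam_min_alpha = 1e-4f) this gives min_sched = 1e-4 / 1e-3 = 0.1, so the remapping min_sched + sched * (1.0f - min_sched) compresses the schedule from [0, 1] into [0.1, 1.0]; for example sched = 0.5 becomes 0.1 + 0.5 * 0.9 = 0.55. The effective learning rate (adam.alpha scaled by adam.sched) therefore never falls below adam_min_alpha, which is why the usage text suggests "usually 0.1 * alpha".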