diff --git a/common/train.cpp b/common/train.cpp
index 3e8b1427f..8e27f5d98 100644
--- a/common/train.cpp
+++ b/common/train.cpp
@@ -1044,6 +1044,7 @@ struct train_params_common get_default_train_params_common() {
     params.n_threads = 6;
     params.n_batch = 8;
     params.n_gradient_accumulation = 1;
+    params.n_epochs = -1;
 
     params.custom_n_ctx = false;
 
@@ -1122,7 +1123,7 @@ void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train
     fprintf(stderr, "  --opt-past N               Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
     fprintf(stderr, "  --opt-delta N              Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
     fprintf(stderr, "  --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
-    fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
+    fprintf(stderr, "  --epochs N                 Maximum number of epochs to process. (default %d)\n", params->n_epochs);
     fprintf(stderr, "  --adam-iter N              Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
     fprintf(stderr, "  --adam-alpha N             Adam learning rate alpha (default %f)\n", params->adam_alpha);
     fprintf(stderr, "  --adam-min-alpha N         Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
@@ -1131,6 +1132,7 @@ void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train
     fprintf(stderr, "  --adam-beta1 N             AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
     fprintf(stderr, "  --adam-beta2 N             AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
     fprintf(stderr, "  --adam-gclip N             AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
+    fprintf(stderr, "  --adam-epsf N              AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
     fprintf(stderr, "\n");
 }
 
@@ -1296,6 +1298,12 @@ bool consume_common_train_arg(
             return true;
         }
         params->adam_eps_f = std::stof(argv[i]);
+    } else if (arg == "--epochs") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_epochs = std::stoi(argv[i]);
     } else if (arg == "--adam-iter") {
         if (++i >= argc) {
             *invalid_param = true;
@@ -1359,7 +1367,7 @@ void finish_processing_train_args(struct train_params_common * params) {
     }
 }
 
-void train_opt_callback(void * vdata, int accum_step, float * sched) {
+void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
     struct train_opt_callback_data * data = (struct train_opt_callback_data *) vdata;
     struct train_params_common * params = data->params;
     struct train_state * train = data->train;
@@ -1475,4 +1483,14 @@ void train_opt_callback(void * vdata, int accum_step, float * sched) {
             data->samples_count);
         train->shuffle_next_sample = 0;
     }
+
+    const bool last_epoch_reached = (params->n_epochs > 0 && train->train_epochs - data->first_epoch >= params->n_epochs);
+    if (last_epoch_reached) {
+        // allow optimization iteration at last epoch to be completed before canceling
+        if (data->iter_at_last_epoch < 0) {
+            data->iter_at_last_epoch = opt->iter;
+        } else if (opt->iter > data->iter_at_last_epoch) {
+            *cancel = true;
+        }
+    }
 }
diff --git a/common/train.h b/common/train.h
index 6ef1f9fc5..42fa704b8 100644
--- a/common/train.h
+++ b/common/train.h
@@ -43,6 +43,7 @@ struct train_params_common {
     int n_threads;
     int n_batch;
     int n_gradient_accumulation;
+    int n_epochs;
 
     bool custom_n_ctx;
 
@@ -101,6 +102,8 @@ struct train_opt_callback_data {
    struct ggml_tensor * tokens_input;
    struct ggml_tensor * target_probs;
    int first_iter;
+   int first_epoch;
+   int iter_at_last_epoch;
    int64_t last_time;
    double millis_per_iter;
 };
@@ -224,4 +227,4 @@ void save_train_state_gguf(struct gguf_context * fctx, struct train_state * trai
 
 std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);
 
-void train_opt_callback(void * vdata, int accum_step, float * sched);
+void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel);
diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index bb6b14547..8c03b9f53 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -1881,6 +1881,8 @@ int main(int argc, char ** argv) {
     opt_cb_data.tokens_input = tokens_input;
     opt_cb_data.target_probs = target_probs;
     opt_cb_data.first_iter = opt->iter;
+    opt_cb_data.first_epoch = train->train_epochs;
+    opt_cb_data.iter_at_last_epoch = -1;
     opt_cb_data.last_time = ggml_time_ms();
     opt_cb_data.millis_per_iter = 0.0;
 
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 56eb816a6..b2b8feb91 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1244,6 +1244,8 @@ int main(int argc, char ** argv) {
     opt_cb_data.tokens_input = tokens_input;
     opt_cb_data.target_probs = target_probs;
     opt_cb_data.first_iter = opt->iter;
+    opt_cb_data.first_epoch = train->train_epochs;
+    opt_cb_data.iter_at_last_epoch = -1;
     opt_cb_data.last_time = ggml_time_ms();
     opt_cb_data.millis_per_iter = 0.0;
 
diff --git a/ggml.c b/ggml.c
index ec9ea80a4..2f17ef650 100644
--- a/ggml.c
+++ b/ggml.c
@@ -19268,14 +19268,17 @@ static enum ggml_opt_result ggml_opt_adam(
     struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
     cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
 
+    bool cancel = false;
     // compute the function value
-    float fx = 0;
     ggml_set_zero(opt->adam.g);
     for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
         if (callback) {
-            callback(callback_data, accum_step, &sched);
+            callback(callback_data, accum_step, &sched, &cancel);
+            if (cancel) {
+                break;
+            }
         }
         // ggml_graph_reset (gf);
         ggml_set_f32 (f->grad, 1.0f);
@@ -19283,6 +19286,9 @@ static enum ggml_opt_result ggml_opt_adam(
         ggml_opt_acc_grad(np, ps, g, accum_norm);
         fx += ggml_get_f32_1d(f, 0);
     }
+    if (cancel) {
+        return GGML_OPT_DID_NOT_CONVERGE;
+    }
     fx *= accum_norm;
 
     opt->adam.fx_prev = fx;
@@ -19308,6 +19314,9 @@ static enum ggml_opt_result ggml_opt_adam(
 
     // run the optimizer
     for (int t = 0; t < params.adam.n_iter; ++t) {
+        if (cancel) {
+            break;
+        }
         opt->iter = iter0 + t + 1;
         GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
 
@@ -19363,7 +19372,10 @@ static enum ggml_opt_result ggml_opt_adam(
         ggml_set_zero(opt->adam.g);
         for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
             if (callback) {
-                callback(callback_data, accum_step, &sched);
+                callback(callback_data, accum_step, &sched, &cancel);
+                if (cancel) {
+                    break;
+                }
             }
             // ggml_graph_reset (gf);
             ggml_set_f32 (f->grad, 1.0f);
@@ -19371,6 +19383,9 @@ static enum ggml_opt_result ggml_opt_adam(
             ggml_opt_acc_grad(np, ps, g, accum_norm);
             fx += ggml_get_f32_1d(f, 0);
         }
+        if (cancel) {
+            break;
+        }
         fx *= accum_norm;
 
         opt->loss_after = fx;
@@ -19456,6 +19471,7 @@ static enum ggml_opt_result linesearch_backtracking(
         struct ggml_cplan * cplan,
         const int np,
         struct ggml_tensor * ps[],
+        bool * cancel,
         ggml_opt_callback callback,
         void * callback_data) {
     int count = 0;
@@ -19488,7 +19504,7 @@ static enum ggml_opt_result linesearch_backtracking(
     finit = *fx;
     dgtest = params->lbfgs.ftol*dginit;
 
-    while (true) {
+    while (!*cancel) {
         ggml_vec_cpy_f32(nx, x, xp);
         ggml_vec_mad_f32(nx, x, d, *step);
 
@@ -19502,7 +19518,10 @@ static enum ggml_opt_result linesearch_backtracking(
             if (callback) {
                 // LBFG-S does not support learning rate -> ignore learning schedule
                 float sched = 0;
-                callback(callback_data, accum_step, &sched);
+                callback(callback_data, accum_step, &sched, cancel);
+                if (*cancel) {
+                    break;
+                }
             }
             // ggml_graph_reset (gf);
             ggml_set_f32 (f->grad, 1.0f);
@@ -19510,6 +19529,9 @@ static enum ggml_opt_result linesearch_backtracking(
             ggml_opt_acc_grad(np, ps, g, accum_norm);
             *fx += ggml_get_f32_1d(f, 0);
         }
+        if (*cancel) {
+            break;
+        }
         *fx *= accum_norm;
     }
 
@@ -19628,6 +19650,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     float * lm_s = opt->lbfgs.lms->data;
     float * lm_y = opt->lbfgs.lmy->data;
 
+    bool cancel = false;
+
     // evaluate the function value and its gradient
     {
         ggml_opt_set_params(np, ps, x);
@@ -19638,7 +19662,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
             if (callback) {
                 // LBFG-S does not support learning rate -> ignore learning schedule
                 float sched = 0;
-                callback(callback_data, accum_step, &sched);
+                callback(callback_data, accum_step, &sched, &cancel);
+                if (cancel) {
+                    break;
+                }
             }
             // ggml_graph_reset (gf);
             ggml_set_f32 (f->grad, 1.0f);
@@ -19646,6 +19673,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
             ggml_opt_acc_grad(np, ps, g, accum_norm);
             fx += ggml_get_f32_1d(f, 0);
         }
+        if (cancel) {
+            return GGML_OPT_DID_NOT_CONVERGE;
+        }
         fx *= accum_norm;
 
         opt->loss_before = fx;
@@ -19704,7 +19734,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, callback, callback_data);
+        ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
+        if (!cancel) {
+            break;
+        }
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
diff --git a/ggml.h b/ggml.h
index 50b849eb8..43d655ffc 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1726,7 +1726,7 @@ extern "C" {
         GGML_LINESEARCH_INVALID_PARAMETERS,
     };
 
-    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched);
+    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
 
     // optimization parameters
     //
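
Note on the new callback contract (context, not part of the patch above): ggml_opt_callback now receives a bool * cancel out-parameter, and ggml_opt_adam, ggml_opt_lbfgs and linesearch_backtracking check that flag after every callback invocation so the optimization can stop early. That is the mechanism train_opt_callback uses to enforce --epochs. The following is a minimal sketch of a user-defined callback written against the new signature; the struct my_cb_data, its fields and the iteration-limit policy are illustrative assumptions, not code from this change.

#include <stdbool.h>

// hypothetical user data for this sketch (not from the patch)
struct my_cb_data {
    int iters_done; // evaluation passes started so far
    int max_iters;  // request cancellation after this many
};

// signature matches the new typedef:
//   typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
static void my_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
    struct my_cb_data * data = (struct my_cb_data *) vdata;

    // keep the full learning-rate schedule (L-BFGS ignores it anyway)
    *sched = 1.0f;

    // the optimizer calls the callback once per gradient-accumulation step;
    // count an evaluation pass only on the first step
    if (accum_step == 0) {
        data->iters_done++;
    }

    // once the flag is set, ggml_opt_adam / ggml_opt_lbfgs see it right after
    // this call returns and stop the optimization early
    *cancel = (data->iters_done > data->max_iters);
}

A callback like this is wired up the same way the examples wire up train_opt_callback: passed, together with a pointer to its data struct, through the callback and callback_data arguments of ggml_opt_resume_g, as finetune.cpp and train-text-from-scratch.cpp do.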