cancel optimization when specified number of epochs is completed

commit da05205af6 (parent 9145c87acc)
Author: xaedes
Date:   2023-09-22 21:00:46 +02:00
6 changed files with 69 additions and 11 deletions

@@ -1044,6 +1044,7 @@ struct train_params_common get_default_train_params_common() {
     params.n_threads = 6;
     params.n_batch = 8;
     params.n_gradient_accumulation = 1;
+    params.n_epochs = -1;
     params.custom_n_ctx = false;
@@ -1122,7 +1123,7 @@ void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train
     fprintf(stderr, " --opt-past N Number of optimization iterations to track for delta convergence test. Disabled when zero. (default %d)\n", params->opt_past);
     fprintf(stderr, " --opt-delta N Maximum delta for delta convergence test. Disabled when <= zero. (default %f)\n", params->opt_delta);
     fprintf(stderr, " --opt-max-no-improvement N Maximum number of optimization iterations with no improvement. Disabled when <= zero. (default %d)\n", params->opt_max_no_improvement);
-    fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
+    fprintf(stderr, " --epochs N Maximum number epochs to process. (default %d)\n", params->n_epochs);
     fprintf(stderr, " --adam-iter N Maximum number of Adam optimization iterations for each batch (default %d)\n", params->adam_n_iter);
     fprintf(stderr, " --adam-alpha N Adam learning rate alpha (default %f)\n", params->adam_alpha);
     fprintf(stderr, " --adam-min-alpha N Adam minimum learning rate alpha - including warmup phase (default %f)\n", params->adam_min_alpha);
@@ -1131,6 +1132,7 @@ void print_common_train_usage(int /*argc*/, char ** /*argv*/, const struct train
     fprintf(stderr, " --adam-beta1 N AdamW beta1 in interval [0,1). How much to smooth the first moment of gradients. (default %f)\n", params->adam_beta1);
     fprintf(stderr, " --adam-beta2 N AdamW beta2 in interval [0,1). How much to smooth the second moment of gradients. (default %f)\n", params->adam_beta2);
     fprintf(stderr, " --adam-gclip N AdamW gradient clipping. Disabled when zero. (default %f)\n", params->adam_gclip);
+    fprintf(stderr, " --adam-epsf N AdamW epsilon for convergence test. Disabled when <= zero. (default %f)\n", params->adam_eps_f);
     fprintf(stderr, "\n");
 }
@@ -1296,6 +1298,12 @@ bool consume_common_train_arg(
             return true;
         }
         params->adam_eps_f = std::stof(argv[i]);
+    } else if (arg == "--epochs") {
+        if (++i >= argc) {
+            *invalid_param = true;
+            return true;
+        }
+        params->n_epochs = std::stoi(argv[i]);
     } else if (arg == "--adam-iter") {
         if (++i >= argc) {
             *invalid_param = true;
@@ -1359,7 +1367,7 @@ void finish_processing_train_args(struct train_params_common * params) {
     }
 }
-void train_opt_callback(void * vdata, int accum_step, float * sched) {
+void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
     struct train_opt_callback_data * data = (struct train_opt_callback_data *) vdata;
     struct train_params_common * params = data->params;
     struct train_state * train = data->train;
@@ -1475,4 +1483,14 @@ void train_opt_callback(void * vdata, int accum_step, float * sched) {
             data->samples_count);
         train->shuffle_next_sample = 0;
     }
+    const bool last_epoch_reached = (params->n_epochs > 0 && train->train_epochs - data->first_epoch >= params->n_epochs);
+    if (last_epoch_reached) {
+        // allow optimization iteration at last epoch to be completed before canceling
+        if (data->iter_at_last_epoch < 0) {
+            data->iter_at_last_epoch = opt->iter;
+        } else if (opt->iter > data->iter_at_last_epoch) {
+            *cancel = true;
+        }
+    }
 }
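For orientation, the hunk above is where the new --epochs limit actually takes effect: once the configured number of epochs has been processed, train_opt_callback raises the new bool * cancel out-parameter and the optimizer stops. The standalone sketch below imitates that contract; the epoch_cancel_state struct, example_opt_callback function, and the toy driver loop are illustrative stand-ins rather than code from this commit, and the real callback additionally lets the optimization iteration that reached the last epoch finish before canceling.

    #include <stdbool.h>
    #include <stdio.h>

    // Hypothetical, simplified stand-in for the fields that train_opt_callback
    // consults for epoch-based cancellation (not part of this commit).
    struct epoch_cancel_state {
        int n_epochs;      // value of --epochs; <= 0 means "no limit"
        int first_epoch;   // epoch counter observed when training started
        int current_epoch; // epoch counter maintained by the training loop
    };

    // Callback with the new signature: setting *cancel asks the optimizer to stop.
    static void example_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel) {
        (void) accum_step;
        *sched = 1.0f; // keep the full learning-rate schedule in this sketch

        const struct epoch_cancel_state * st = (const struct epoch_cancel_state *) vdata;
        const bool last_epoch_reached =
            st->n_epochs > 0 && st->current_epoch - st->first_epoch >= st->n_epochs;
        if (last_epoch_reached) {
            *cancel = true;
        }
    }

    // Toy driver standing in for the optimizer loop: it checks the flag after
    // every callback invocation, as the ggml_opt_adam changes below do.
    int main(void) {
        struct epoch_cancel_state st = { /*n_epochs*/ 2, /*first_epoch*/ 0, /*current_epoch*/ 0 };
        bool cancel = false;
        for (int iter = 0; iter < 100 && !cancel; ++iter) {
            float sched = 0.0f;
            example_opt_callback(&st, 0, &sched, &cancel);
            st.current_epoch = iter / 10; // pretend every 10 iterations complete one epoch
        }
        printf("stopped after epoch %d\n", st.current_epoch);
        return 0;
    }

With --epochs 2, for example, training stops shortly after the second epoch completes instead of running the full --adam-iter budget.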

@@ -43,6 +43,7 @@ struct train_params_common {
     int n_threads;
     int n_batch;
     int n_gradient_accumulation;
+    int n_epochs;
     bool custom_n_ctx;
@@ -101,6 +102,8 @@ struct train_opt_callback_data {
     struct ggml_tensor * tokens_input;
     struct ggml_tensor * target_probs;
     int first_iter;
+    int first_epoch;
+    int iter_at_last_epoch;
     int64_t last_time;
     double millis_per_iter;
 };
@@ -224,4 +227,4 @@ void save_train_state_gguf(struct gguf_context * fctx, struct train_state * trai
 std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration);
-void train_opt_callback(void * vdata, int accum_step, float * sched);
+void train_opt_callback(void * vdata, int accum_step, float * sched, bool * cancel);

@@ -1881,6 +1881,8 @@ int main(int argc, char ** argv) {
     opt_cb_data.tokens_input = tokens_input;
     opt_cb_data.target_probs = target_probs;
     opt_cb_data.first_iter = opt->iter;
+    opt_cb_data.first_epoch = train->train_epochs;
+    opt_cb_data.iter_at_last_epoch = -1;
     opt_cb_data.last_time = ggml_time_ms();
     opt_cb_data.millis_per_iter = 0.0;

@@ -1244,6 +1244,8 @@ int main(int argc, char ** argv) {
     opt_cb_data.tokens_input = tokens_input;
     opt_cb_data.target_probs = target_probs;
     opt_cb_data.first_iter = opt->iter;
+    opt_cb_data.first_epoch = train->train_epochs;
+    opt_cb_data.iter_at_last_epoch = -1;
     opt_cb_data.last_time = ggml_time_ms();
     opt_cb_data.millis_per_iter = 0.0;

ggml.c

@@ -19268,14 +19268,17 @@ static enum ggml_opt_result ggml_opt_adam(
     struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
     cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
+    bool cancel = false;
     // compute the function value
     float fx = 0;
     ggml_set_zero(opt->adam.g);
     for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
         if (callback) {
-            callback(callback_data, accum_step, &sched);
+            callback(callback_data, accum_step, &sched, &cancel);
+            if (cancel) {
+                break;
+            }
         }
         // ggml_graph_reset (gf);
         ggml_set_f32 (f->grad, 1.0f);
@@ -19283,6 +19286,9 @@ static enum ggml_opt_result ggml_opt_adam(
         ggml_opt_acc_grad(np, ps, g, accum_norm);
         fx += ggml_get_f32_1d(f, 0);
     }
+    if (cancel) {
+        return GGML_OPT_DID_NOT_CONVERGE;
+    }
     fx *= accum_norm;
     opt->adam.fx_prev = fx;
@@ -19308,6 +19314,9 @@ static enum ggml_opt_result ggml_opt_adam(
     // run the optimizer
     for (int t = 0; t < params.adam.n_iter; ++t) {
+        if (cancel) {
+            break;
+        }
         opt->iter = iter0 + t + 1;
         GGML_PRINT_DEBUG ("=== iter %d ===\n", t);
@@ -19363,7 +19372,10 @@ static enum ggml_opt_result ggml_opt_adam(
         ggml_set_zero(opt->adam.g);
         for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
             if (callback) {
-                callback(callback_data, accum_step, &sched);
+                callback(callback_data, accum_step, &sched, &cancel);
+                if (cancel) {
+                    break;
+                }
             }
             // ggml_graph_reset (gf);
             ggml_set_f32 (f->grad, 1.0f);
@@ -19371,6 +19383,9 @@ static enum ggml_opt_result ggml_opt_adam(
             ggml_opt_acc_grad(np, ps, g, accum_norm);
             fx += ggml_get_f32_1d(f, 0);
         }
+        if (cancel) {
+            break;
+        }
         fx *= accum_norm;
         opt->loss_after = fx;
@@ -19456,6 +19471,7 @@ static enum ggml_opt_result linesearch_backtracking(
         struct ggml_cplan * cplan,
         const int np,
         struct ggml_tensor * ps[],
+        bool * cancel,
         ggml_opt_callback callback,
         void * callback_data) {
     int count = 0;
@@ -19488,7 +19504,7 @@ static enum ggml_opt_result linesearch_backtracking(
     finit = *fx;
     dgtest = params->lbfgs.ftol*dginit;
-    while (true) {
+    while (!*cancel) {
         ggml_vec_cpy_f32(nx, x, xp);
         ggml_vec_mad_f32(nx, x, d, *step);
@@ -19502,7 +19518,10 @@ static enum ggml_opt_result linesearch_backtracking(
             if (callback) {
                 // LBFG-S does not support learning rate -> ignore learning schedule
                 float sched = 0;
-                callback(callback_data, accum_step, &sched);
+                callback(callback_data, accum_step, &sched, cancel);
+                if (*cancel) {
+                    break;
+                }
             }
             // ggml_graph_reset (gf);
             ggml_set_f32 (f->grad, 1.0f);
@@ -19510,6 +19529,9 @@ static enum ggml_opt_result linesearch_backtracking(
             ggml_opt_acc_grad(np, ps, g, accum_norm);
             *fx += ggml_get_f32_1d(f, 0);
         }
+        if (*cancel) {
+            break;
+        }
         *fx *= accum_norm;
     }
@@ -19628,6 +19650,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     float * lm_s = opt->lbfgs.lms->data;
     float * lm_y = opt->lbfgs.lmy->data;
+    bool cancel = false;
     // evaluate the function value and its gradient
     {
         ggml_opt_set_params(np, ps, x);
@@ -19638,7 +19662,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
             if (callback) {
                 // LBFG-S does not support learning rate -> ignore learning schedule
                 float sched = 0;
-                callback(callback_data, accum_step, &sched);
+                callback(callback_data, accum_step, &sched, &cancel);
+                if (cancel) {
+                    break;
+                }
             }
             // ggml_graph_reset (gf);
             ggml_set_f32 (f->grad, 1.0f);
@@ -19646,6 +19673,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
            ggml_opt_acc_grad(np, ps, g, accum_norm);
            fx += ggml_get_f32_1d(f, 0);
         }
+        if (cancel) {
+            return GGML_OPT_DID_NOT_CONVERGE;
+        }
         fx *= accum_norm;
         opt->loss_before = fx;
@@ -19704,7 +19734,10 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
-        ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, callback, callback_data);
+        ls = linesearch_backtracking(&params, nx, x, &fx, g, d, step, xp, f, gb, &cplan, np, ps, &cancel, callback, callback_data);
+        if (!cancel) {
+            break;
+        }
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return

ggml.h

@@ -1726,7 +1726,7 @@ extern "C" {
         GGML_LINESEARCH_INVALID_PARAMETERS,
     };
-    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched);
+    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel);
     // optimization parameters
     //
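For downstream users of the optimizer API, the visible change is this extra bool * cancel parameter on ggml_opt_callback. A minimal sketch of adapting an existing callback follows; the function name my_opt_callback is hypothetical. A callback that never wants to stop early can simply leave the flag untouched, while setting it to true makes ggml_opt_adam and ggml_opt_lbfgs bail out as shown in the ggml.c hunks above.

    #include <stdbool.h>

    // Previous signature, for comparison:
    //   typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched);

    // Updated signature: the optimizer passes a pointer to its local cancel flag,
    // which starts out false and is re-checked after every callback invocation.
    static void my_opt_callback(void * data, int accum_step, float * sched, bool * cancel) {
        (void) data;
        (void) accum_step;
        *sched = 1.0f;   // keep the learning-rate schedule unchanged
        (void) cancel;   // leave *cancel as-is to continue; set it to true to stop early
    }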