From bfc311913991c75cc2d3c2978d9d2273a1370ac6 Mon Sep 17 00:00:00 2001 From: xaedes Date: Sun, 2 Jul 2023 22:15:08 +0200 Subject: [PATCH] add optimization callback to ggml_opt_resume_g this callback is called before each iteration with custom data and pointer to learning schedule parameter (only used in Adam(W)). can be used for dynamic learning schedule and setting input data for batches before each iteration --- .../train-text-from-scratch.cpp | 14 +--- ggml.c | 71 ++++++++++++++----- ggml.h | 15 ++-- 3 files changed, 69 insertions(+), 31 deletions(-) diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 0f330fd4a..6adbece4c 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -4046,12 +4046,8 @@ int main(int argc, char ** argv) { ggml_build_backward_expand(ctx0, gf, gb, true); } - ggml_graph_compute_helper(work_buffer, gf, params.n_threads); - size_t used_mem_before_opt = ggml_used_mem(ctx0); - float error_before_opt = ggml_get_f32_1d(loss, 0); - opt->params.adam.sched = (opt->iter < params.warmup) ? (float) opt->iter / (float) params.warmup : cosine_decay_restart( @@ -4066,7 +4062,7 @@ int main(int argc, char ** argv) { printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched); - ggml_opt_resume_g(ctx0, opt, loss, gf, gb); + ggml_opt_resume_g(ctx0, opt, loss, gf, gb, NULL, NULL); size_t used_mem_after_opt = ggml_used_mem(ctx0); @@ -4074,14 +4070,10 @@ int main(int argc, char ** argv) { model.train_samples += n_batch; model.train_tokens += n_batch * n_tokens; - ggml_graph_compute_helper(work_buffer, gf, params.n_threads); - - float error_after_opt = ggml_get_f32_1d(loss, 0); - if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) { printf("Example %d, opt iter %d\n", ex, opt->iter); - printf("error_before_opt: %.6f\n", error_before_opt); - printf("error_after_opt: %.6f\n", error_after_opt); + printf("error_before_opt: %.6f\n", opt->loss_before); + printf("error_after_opt: %.6f\n", opt->loss_after); printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt); printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt); } diff --git a/ggml.c b/ggml.c index 07d100bf0..e0f91ed5a 100644 --- a/ggml.c +++ b/ggml.c @@ -17281,7 +17281,9 @@ static enum ggml_opt_result ggml_opt_adam( struct ggml_opt_params params, struct ggml_tensor * f, struct ggml_cgraph * gf, - struct ggml_cgraph * gb) { + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { GGML_ASSERT(ggml_is_scalar(f)); // these will store the parameters we want to optimize @@ -17307,8 +17309,8 @@ static enum ggml_opt_result ggml_opt_adam( } // constants - const float sched = params.adam.sched; - const float alpha = params.adam.alpha * sched; + float sched = params.adam.sched; + const float alpha = params.adam.alpha; const float decay = params.adam.decay * alpha; const float beta1 = params.adam.beta1; const float beta2 = params.adam.beta2; @@ -17320,6 +17322,10 @@ static enum ggml_opt_result ggml_opt_adam( float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values + if (callback) { + callback(callback_data, &sched); + } + // compute the function value ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); @@ -17332,6 +17338,9 @@ static enum ggml_opt_result ggml_opt_adam( pf[opt->iter % params.past] = opt->adam.fx_prev; } + opt->loss_before = opt->adam.fx_prev; + opt->loss_after = opt->adam.fx_prev; + // initialize if (opt->just_initialized) { opt->adam.n_no_improvement = 0; @@ -17380,11 +17389,12 @@ static enum ggml_opt_result ggml_opt_adam( gnorm = (float) ((ggml_float) gclip / norm); } } - const float beta1h = alpha/(1.0f - powf(beta1, opt->iter)); - const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter)); + const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter)); + const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter)); int64_t i = 0; for (int p = 0; p < np; ++p) { const int64_t ne = ggml_nelements(ps[p]); + const float p_decay = decay * sched; for (int64_t j = 0; j < ne; ++j) { float x = ggml_get_f32_1d(ps[p], j); float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm; @@ -17393,13 +17403,13 @@ static enum ggml_opt_result ggml_opt_adam( float mh = m[i]*beta1h; float vh = v[i]*beta2h; vh = sqrtf(vh) + eps; - x = x*(1.0f - decay) - mh/vh; + x = x*(1.0f - p_decay) - mh/vh; ggml_set_f32_1d(ps[p], j, x); ++i; } } } - // { + { // // update the gradient // ggml_opt_get_grad(np, ps, g1); @@ -17436,7 +17446,11 @@ static enum ggml_opt_result ggml_opt_adam( // // update the parameters // ggml_opt_set_params(np, ps, x); - // } + } + + if (callback) { + callback(callback_data, &sched); + } ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); @@ -17444,6 +17458,8 @@ static enum ggml_opt_result ggml_opt_adam( ggml_graph_compute_with_ctx(ctx, gb, params.n_threads); const float fx = ggml_get_f32_1d(f, 0); + opt->loss_after = fx; + // check convergence if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) { @@ -17525,7 +17541,9 @@ static enum ggml_opt_result linesearch_backtracking( struct ggml_cgraph * gf, struct ggml_cgraph * gb, const int np, - struct ggml_tensor * ps[]) { + struct ggml_tensor * ps[], + ggml_opt_callback callback, + void * callback_data) { int count = 0; float width = 0.0f; @@ -17554,6 +17572,12 @@ static enum ggml_opt_result linesearch_backtracking( dgtest = params->lbfgs.ftol*dginit; while (true) { + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, &sched); + } + ggml_vec_cpy_f32(nx, x, xp); ggml_vec_mad_f32(nx, x, d, *step); @@ -17624,7 +17648,9 @@ static enum ggml_opt_result ggml_opt_lbfgs( struct ggml_opt_params params, struct ggml_tensor * f, struct ggml_cgraph * gf, - struct ggml_cgraph * gb) { + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE || params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) { if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) { @@ -17677,6 +17703,12 @@ static enum ggml_opt_result ggml_opt_lbfgs( float * lm_s = opt->lbfgs.lms->data; float * lm_y = opt->lbfgs.lmy->data; + if (callback) { + // LBFG-S does not support learning rate -> ignore learning schedule + float sched = 0; + callback(callback_data, &sched); + } + // evaluate the function value and its gradient { ggml_opt_set_params(np, ps, x); @@ -17689,6 +17721,9 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_opt_get_grad(np, ps, g); fx = ggml_get_f32_1d(f, 0); + + opt->loss_before = fx; + opt->loss_after = fx; } // search direction = -gradient @@ -17743,7 +17778,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_vec_cpy_f32(nx, xp, x); ggml_vec_cpy_f32(nx, gp, g); - ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps); + ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps, callback, callback_data); if (ls < 0) { // linesearch failed - go back to the previous point and return @@ -17753,6 +17788,8 @@ static enum ggml_opt_result ggml_opt_lbfgs( return ls; } + opt->loss_after = fx; + ggml_vec_norm_f32(nx, &xnorm, x); ggml_vec_norm_f32(nx, &gnorm, g); @@ -17810,7 +17847,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( // ys = y^t \cdot s -> 1 / \rho. // yy = y^t \cdot y. // - ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]); + ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]); ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]); lm_ys[end[0]] = ys; @@ -18020,7 +18057,7 @@ enum ggml_opt_result ggml_opt_resume( *gf = ggml_build_forward (f); *gb = ggml_build_backward(ctx, gf, true); - return ggml_opt_resume_g(ctx, opt, f, gf, gb); + return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); } enum ggml_opt_result ggml_opt_resume_g( @@ -18028,7 +18065,9 @@ enum ggml_opt_result ggml_opt_resume_g( struct ggml_opt_context * opt, struct ggml_tensor * f, struct ggml_cgraph * gf, - struct ggml_cgraph * gb) { + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data) { // build forward + backward compute graphs enum ggml_opt_result result = GGML_OPT_OK; @@ -18036,11 +18075,11 @@ enum ggml_opt_result ggml_opt_resume_g( switch (opt->params.type) { case GGML_OPT_ADAM: { - result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb); + result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data); } break; case GGML_OPT_LBFGS: { - result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb); + result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data); } break; } diff --git a/ggml.h b/ggml.h index 8f51f5d22..fadc343ee 100644 --- a/ggml.h +++ b/ggml.h @@ -1469,6 +1469,8 @@ extern "C" { GGML_LINESEARCH_INVALID_PARAMETERS, }; + typedef void (*ggml_opt_callback)(void * data, float * sched); + // optimization parameters // // see ggml.c (ggml_opt_default_params) for default values @@ -1538,6 +1540,9 @@ extern "C" { bool just_initialized; + float loss_before; + float loss_after; + struct { struct ggml_tensor * m; // first moment struct ggml_tensor * v; // second moment @@ -1577,10 +1582,10 @@ extern "C" { // initialize optimizer context GGML_API void ggml_opt_init( - struct ggml_context * ctx, + struct ggml_context * ctx, struct ggml_opt_context * opt, - struct ggml_opt_params params, - int64_t nx); + struct ggml_opt_params params, + int64_t nx); // continue optimizing the function defined by the tensor f GGML_API enum ggml_opt_result ggml_opt_resume( @@ -1594,7 +1599,9 @@ extern "C" { struct ggml_opt_context * opt, struct ggml_tensor * f, struct ggml_cgraph * gf, - struct ggml_cgraph * gb); + struct ggml_cgraph * gb, + ggml_opt_callback callback, + void * callback_data); // // quantization