diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 0f330fd4a..6adbece4c 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -4046,12 +4046,8 @@ int main(int argc, char ** argv) {
             ggml_build_backward_expand(ctx0, gf, gb, true);
         }
 
-        ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
-
         size_t used_mem_before_opt = ggml_used_mem(ctx0);
 
-        float error_before_opt = ggml_get_f32_1d(loss, 0);
-
         opt->params.adam.sched = (opt->iter < params.warmup)
             ? (float) opt->iter / (float) params.warmup
             : cosine_decay_restart(
@@ -4066,7 +4062,7 @@ int main(int argc, char ** argv) {
 
         printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
 
-        ggml_opt_resume_g(ctx0, opt, loss, gf, gb);
+        ggml_opt_resume_g(ctx0, opt, loss, gf, gb, NULL, NULL);
 
         size_t used_mem_after_opt = ggml_used_mem(ctx0);
 
@@ -4074,14 +4070,10 @@ int main(int argc, char ** argv) {
         model.train_samples += n_batch;
         model.train_tokens += n_batch * n_tokens;
 
-        ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
-
-        float error_after_opt = ggml_get_f32_1d(loss, 0);
-
         if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
             printf("Example %d, opt iter %d\n", ex, opt->iter);
-            printf("error_before_opt: %.6f\n", error_before_opt);
-            printf("error_after_opt: %.6f\n", error_after_opt);
+            printf("error_before_opt: %.6f\n", opt->loss_before);
+            printf("error_after_opt: %.6f\n", opt->loss_after);
             printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
             printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
         }
diff --git a/ggml.c b/ggml.c
index 07d100bf0..e0f91ed5a 100644
--- a/ggml.c
+++ b/ggml.c
@@ -17281,7 +17281,9 @@ static enum ggml_opt_result ggml_opt_adam(
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
-        struct ggml_cgraph * gb) {
+        struct ggml_cgraph * gb,
+        ggml_opt_callback callback,
+        void * callback_data) {
     GGML_ASSERT(ggml_is_scalar(f));
 
     // these will store the parameters we want to optimize
@@ -17307,8 +17309,8 @@ static enum ggml_opt_result ggml_opt_adam(
     }
 
     // constants
-    const float sched = params.adam.sched;
-    const float alpha = params.adam.alpha * sched;
+    float sched = params.adam.sched;
+    const float alpha = params.adam.alpha;
     const float decay = params.adam.decay * alpha;
     const float beta1 = params.adam.beta1;
     const float beta2 = params.adam.beta2;
@@ -17320,6 +17322,10 @@ static enum ggml_opt_result ggml_opt_adam(
 
     float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
 
+    if (callback) {
+        callback(callback_data, &sched);
+    }
+
     // compute the function value
     ggml_graph_reset (gf);
     ggml_set_f32 (f->grad, 1.0f);
@@ -17332,6 +17338,9 @@ static enum ggml_opt_result ggml_opt_adam(
         pf[opt->iter % params.past] = opt->adam.fx_prev;
     }
 
+    opt->loss_before = opt->adam.fx_prev;
+    opt->loss_after = opt->adam.fx_prev;
+
     // initialize
     if (opt->just_initialized) {
         opt->adam.n_no_improvement = 0;
@@ -17380,11 +17389,12 @@ static enum ggml_opt_result ggml_opt_adam(
                     gnorm = (float) ((ggml_float) gclip / norm);
                 }
             }
-            const float beta1h = alpha/(1.0f - powf(beta1, opt->iter));
-            const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter));
+            const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter));
+            const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter));
             int64_t i = 0;
             for (int p = 0; p < np; ++p) {
                 const int64_t ne = ggml_nelements(ps[p]);
+                const float p_decay = decay * sched;
                 for (int64_t j = 0; j < ne; ++j) {
                     float x = ggml_get_f32_1d(ps[p], j);
                     float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
@@ -17393,13 +17403,13 @@ static enum ggml_opt_result ggml_opt_adam(
                     float mh = m[i]*beta1h;
                     float vh = v[i]*beta2h;
                     vh = sqrtf(vh) + eps;
-                    x = x*(1.0f - decay) - mh/vh;
+                    x = x*(1.0f - p_decay) - mh/vh;
                     ggml_set_f32_1d(ps[p], j, x);
                     ++i;
                 }
             }
         }
-        // {
+        {
         // // update the gradient
         // ggml_opt_get_grad(np, ps, g1);
 
@@ -17436,7 +17446,11 @@
         // // update the parameters
         // ggml_opt_set_params(np, ps, x);
 
-        // }
+        }
+
+        if (callback) {
+            callback(callback_data, &sched);
+        }
 
         ggml_graph_reset (gf);
         ggml_set_f32 (f->grad, 1.0f);
@@ -17444,6 +17458,8 @@
         ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
 
         const float fx = ggml_get_f32_1d(f, 0);
 
+        opt->loss_after = fx;
+
         // check convergence
         if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
@@ -17525,7 +17541,9 @@ static enum ggml_opt_result linesearch_backtracking(
         struct ggml_cgraph * gf,
         struct ggml_cgraph * gb,
         const int np,
-        struct ggml_tensor * ps[]) {
+        struct ggml_tensor * ps[],
+        ggml_opt_callback callback,
+        void * callback_data) {
     int count = 0;
 
     float width = 0.0f;
@@ -17554,6 +17572,12 @@ static enum ggml_opt_result linesearch_backtracking(
     dgtest = params->lbfgs.ftol*dginit;
 
     while (true) {
+        if (callback) {
+            // LBFG-S does not support learning rate -> ignore learning schedule
+            float sched = 0;
+            callback(callback_data, &sched);
+        }
+
         ggml_vec_cpy_f32(nx, x, xp);
         ggml_vec_mad_f32(nx, x, d, *step);
 
@@ -17624,7 +17648,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         struct ggml_opt_params params,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
-        struct ggml_cgraph * gb) {
+        struct ggml_cgraph * gb,
+        ggml_opt_callback callback,
+        void * callback_data) {
     if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
         params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
         if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
@@ -17677,6 +17703,12 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     float * lm_s = opt->lbfgs.lms->data;
     float * lm_y = opt->lbfgs.lmy->data;
 
+    if (callback) {
+        // LBFG-S does not support learning rate -> ignore learning schedule
+        float sched = 0;
+        callback(callback_data, &sched);
+    }
+
     // evaluate the function value and its gradient
     {
         ggml_opt_set_params(np, ps, x);
@@ -17689,6 +17721,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_opt_get_grad(np, ps, g);
 
         fx = ggml_get_f32_1d(f, 0);
+
+        opt->loss_before = fx;
+        opt->loss_after = fx;
     }
 
     // search direction = -gradient
@@ -17743,7 +17778,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         ggml_vec_cpy_f32(nx, xp, x);
         ggml_vec_cpy_f32(nx, gp, g);
 
-        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
+        ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps, callback, callback_data);
 
         if (ls < 0) {
             // linesearch failed - go back to the previous point and return
@@ -17753,6 +17788,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
             return ls;
         }
 
+        opt->loss_after = fx;
+
         ggml_vec_norm_f32(nx, &xnorm, x);
         ggml_vec_norm_f32(nx, &gnorm, g);
 
@@ -17810,7 +17847,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
         // ys = y^t \cdot s -> 1 / \rho.
         // yy = y^t \cdot y.
         //
-        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
+        ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
         ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
 
         lm_ys[end[0]] = ys;
@@ -18020,7 +18057,7 @@ enum ggml_opt_result ggml_opt_resume(
     *gf = ggml_build_forward (f);
     *gb = ggml_build_backward(ctx, gf, true);
 
-    return ggml_opt_resume_g(ctx, opt, f, gf, gb);
+    return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
 }
 
 enum ggml_opt_result ggml_opt_resume_g(
@@ -18028,7 +18065,9 @@ enum ggml_opt_result ggml_opt_resume_g(
         struct ggml_opt_context * opt,
         struct ggml_tensor * f,
         struct ggml_cgraph * gf,
-        struct ggml_cgraph * gb) {
+        struct ggml_cgraph * gb,
+        ggml_opt_callback callback,
+        void * callback_data) {
 
     // build forward + backward compute graphs
     enum ggml_opt_result result = GGML_OPT_OK;
@@ -18036,11 +18075,11 @@ enum ggml_opt_result ggml_opt_resume_g(
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
+                result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
             } break;
         case GGML_OPT_LBFGS:
            {
-                result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
+                result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
            } break;
     }
 
diff --git a/ggml.h b/ggml.h
index 8f51f5d22..fadc343ee 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1469,6 +1469,8 @@ extern "C" {
         GGML_LINESEARCH_INVALID_PARAMETERS,
     };
 
+    typedef void (*ggml_opt_callback)(void * data, float * sched);
+
     // optimization parameters
     //
     //   see ggml.c (ggml_opt_default_params) for default values
@@ -1538,6 +1540,9 @@ extern "C" {
 
         bool just_initialized;
 
+        float loss_before;
+        float loss_after;
+
         struct {
             struct ggml_tensor * m;  // first moment
             struct ggml_tensor * v;  // second moment
@@ -1577,10 +1582,10 @@ extern "C" {
 
     // initialize optimizer context
     GGML_API void ggml_opt_init(
-            struct ggml_context * ctx,
+            struct ggml_context     * ctx,
             struct ggml_opt_context * opt,
-            struct ggml_opt_params params,
-            int64_t nx);
+            struct ggml_opt_params    params,
+            int64_t                   nx);
 
     // continue optimizing the function defined by the tensor f
     GGML_API enum ggml_opt_result ggml_opt_resume(
@@ -1594,7 +1599,9 @@ extern "C" {
             struct ggml_opt_context * opt,
             struct ggml_tensor * f,
             struct ggml_cgraph * gf,
-            struct ggml_cgraph * gb);
+            struct ggml_cgraph * gb,
+            ggml_opt_callback callback,
+            void * callback_data);
 
     //
     // quantization
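For context, a minimal sketch of how a caller could use the new callback parameter of ggml_opt_resume_g(). Only the ggml_opt_callback typedef, the extended argument lists, and the opt->loss_before / opt->loss_after fields come from the patch itself; the struct name, callback name, and warmup schedule below are hypothetical illustrations. The callback is invoked before each graph evaluation with the user data pointer and a pointer to the learning-rate schedule factor; Adam(W) scales its step size and weight decay by *sched, while L-BFGS calls the callback with sched = 0 and ignores the value.

#include "ggml.h" // for ggml_opt_callback / ggml_opt_resume_g

// hypothetical user data for the callback (not part of the patch)
struct sched_data {
    int iter;   // optimizer iterations seen so far
    int warmup; // number of warmup iterations
};

// signature matches: typedef void (*ggml_opt_callback)(void * data, float * sched);
static void warmup_sched_callback(void * data, float * sched) {
    struct sched_data * sd = (struct sched_data *) data;
    // linear warmup from 0 to 1, then a constant schedule factor of 1.0
    *sched = (sd->iter < sd->warmup)
        ? (float) sd->iter / (float) sd->warmup
        : 1.0f;
    sd->iter++;
}

// usage sketch, assuming ctx0, opt, loss, gf and gb are set up as in the example;
// passing NULL, NULL (as the example still does) keeps the previous behavior:
//
//     struct sched_data sd = { 0, 100 };
//     ggml_opt_resume_g(ctx0, opt, loss, gf, gb, warmup_sched_callback, &sd);
//     printf("loss before/after: %.6f / %.6f\n", opt->loss_before, opt->loss_after);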