add optimization callback to ggml_opt_resume_g

The callback is called before each iteration with the custom data and a pointer to the learning schedule parameter (only used by Adam(W)).

It can be used to implement a dynamic learning schedule and to set the input data for the next batch before each iteration.
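
As an illustration, a caller might implement the callback roughly as sketched below. This is a minimal sketch: the train_state struct and the set_batch_inputs() helper are hypothetical names, not part of this commit; only the ggml_opt_callback signature, opt->iter, and the sched parameter come from the change itself.

    // hypothetical user data passed as callback_data
    struct train_state {
        struct ggml_opt_context * opt;
        int   warmup;         // number of warmup iterations
        float target_sched;   // learning schedule value after warmup
    };

    // matches: typedef void (*ggml_opt_callback)(void * data, float * sched);
    // called before each Adam(W) iteration; the value written to *sched is
    // picked up by the optimizer (L-BFGS ignores it)
    static void train_callback(void * data, float * sched) {
        struct train_state * ts = (struct train_state *) data;
        // linear warmup, then a constant learning schedule
        *sched = (ts->opt->iter < ts->warmup)
            ? (float) ts->opt->iter / (float) ts->warmup
            : ts->target_sched;
        // the input tensors for the next batch could also be filled here,
        // e.g. set_batch_inputs(ts, ts->opt->iter);  // hypothetical helper
    }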
xaedes 2023-07-02 22:15:08 +02:00
parent e843d6e71c
commit bfc3119139
3 changed files with 69 additions and 31 deletions


@@ -4046,12 +4046,8 @@ int main(int argc, char ** argv) {
ggml_build_backward_expand(ctx0, gf, gb, true);
}
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
size_t used_mem_before_opt = ggml_used_mem(ctx0);
float error_before_opt = ggml_get_f32_1d(loss, 0);
opt->params.adam.sched = (opt->iter < params.warmup)
? (float) opt->iter / (float) params.warmup
: cosine_decay_restart(
@@ -4066,7 +4062,7 @@ int main(int argc, char ** argv) {
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
ggml_opt_resume_g(ctx0, opt, loss, gf, gb);
ggml_opt_resume_g(ctx0, opt, loss, gf, gb, NULL, NULL);
size_t used_mem_after_opt = ggml_used_mem(ctx0);
@@ -4074,14 +4070,10 @@ int main(int argc, char ** argv) {
model.train_samples += n_batch;
model.train_tokens += n_batch * n_tokens;
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
float error_after_opt = ggml_get_f32_1d(loss, 0);
if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
printf("Example %d, opt iter %d\n", ex, opt->iter);
printf("error_before_opt: %.6f\n", error_before_opt);
printf("error_after_opt: %.6f\n", error_after_opt);
printf("error_before_opt: %.6f\n", opt->loss_before);
printf("error_after_opt: %.6f\n", opt->loss_after);
printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
}

ggml.c (71 changed lines)

@@ -17281,7 +17281,9 @@ static enum ggml_opt_result ggml_opt_adam(
struct ggml_opt_params params,
struct ggml_tensor * f,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb) {
struct ggml_cgraph * gb,
ggml_opt_callback callback,
void * callback_data) {
GGML_ASSERT(ggml_is_scalar(f));
// these will store the parameters we want to optimize
@@ -17307,8 +17309,8 @@ static enum ggml_opt_result ggml_opt_adam(
}
// constants
const float sched = params.adam.sched;
const float alpha = params.adam.alpha * sched;
float sched = params.adam.sched;
const float alpha = params.adam.alpha;
const float decay = params.adam.decay * alpha;
const float beta1 = params.adam.beta1;
const float beta2 = params.adam.beta2;
@@ -17320,6 +17322,10 @@ static enum ggml_opt_result ggml_opt_adam(
float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
if (callback) {
callback(callback_data, &sched);
}
// compute the function value
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
@@ -17332,6 +17338,9 @@ static enum ggml_opt_result ggml_opt_adam(
pf[opt->iter % params.past] = opt->adam.fx_prev;
}
opt->loss_before = opt->adam.fx_prev;
opt->loss_after = opt->adam.fx_prev;
// initialize
if (opt->just_initialized) {
opt->adam.n_no_improvement = 0;
@@ -17380,11 +17389,12 @@ static enum ggml_opt_result ggml_opt_adam(
gnorm = (float) ((ggml_float) gclip / norm);
}
}
const float beta1h = alpha/(1.0f - powf(beta1, opt->iter));
const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter));
const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter));
const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter));
int64_t i = 0;
for (int p = 0; p < np; ++p) {
const int64_t ne = ggml_nelements(ps[p]);
const float p_decay = decay * sched;
for (int64_t j = 0; j < ne; ++j) {
float x = ggml_get_f32_1d(ps[p], j);
float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
@@ -17393,13 +17403,13 @@ static enum ggml_opt_result ggml_opt_adam(
float mh = m[i]*beta1h;
float vh = v[i]*beta2h;
vh = sqrtf(vh) + eps;
x = x*(1.0f - decay) - mh/vh;
x = x*(1.0f - p_decay) - mh/vh;
ggml_set_f32_1d(ps[p], j, x);
++i;
}
}
}
// {
{
// // update the gradient
// ggml_opt_get_grad(np, ps, g1);
@@ -17436,7 +17446,11 @@ static enum ggml_opt_result ggml_opt_adam(
// // update the parameters
// ggml_opt_set_params(np, ps, x);
// }
}
if (callback) {
callback(callback_data, &sched);
}
ggml_graph_reset (gf);
ggml_set_f32 (f->grad, 1.0f);
@@ -17444,6 +17458,8 @@
ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
const float fx = ggml_get_f32_1d(f, 0);
opt->loss_after = fx;
// check convergence
if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
@@ -17525,7 +17541,9 @@ static enum ggml_opt_result linesearch_backtracking(
struct ggml_cgraph * gf,
struct ggml_cgraph * gb,
const int np,
struct ggml_tensor * ps[]) {
struct ggml_tensor * ps[],
ggml_opt_callback callback,
void * callback_data) {
int count = 0;
float width = 0.0f;
@@ -17554,6 +17572,12 @@
dgtest = params->lbfgs.ftol*dginit;
while (true) {
if (callback) {
// LBFG-S does not support learning rate -> ignore learning schedule
float sched = 0;
callback(callback_data, &sched);
}
ggml_vec_cpy_f32(nx, x, xp);
ggml_vec_mad_f32(nx, x, d, *step);
@@ -17624,7 +17648,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
struct ggml_opt_params params,
struct ggml_tensor * f,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb) {
struct ggml_cgraph * gb,
ggml_opt_callback callback,
void * callback_data) {
if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
@@ -17677,6 +17703,12 @@ static enum ggml_opt_result ggml_opt_lbfgs(
float * lm_s = opt->lbfgs.lms->data;
float * lm_y = opt->lbfgs.lmy->data;
if (callback) {
// LBFG-S does not support learning rate -> ignore learning schedule
float sched = 0;
callback(callback_data, &sched);
}
// evaluate the function value and its gradient
{
ggml_opt_set_params(np, ps, x);
@@ -17689,6 +17721,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_opt_get_grad(np, ps, g);
fx = ggml_get_f32_1d(f, 0);
opt->loss_before = fx;
opt->loss_after = fx;
}
// search direction = -gradient
@@ -17743,7 +17778,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
ggml_vec_cpy_f32(nx, xp, x);
ggml_vec_cpy_f32(nx, gp, g);
ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
ls = linesearch_backtracking(ctx, &params, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps, callback, callback_data);
if (ls < 0) {
// linesearch failed - go back to the previous point and return
@@ -17753,6 +17788,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
return ls;
}
opt->loss_after = fx;
ggml_vec_norm_f32(nx, &xnorm, x);
ggml_vec_norm_f32(nx, &gnorm, g);
@@ -17810,7 +17847,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
// ys = y^t \cdot s -> 1 / \rho.
// yy = y^t \cdot y.
//
ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
lm_ys[end[0]] = ys;
@@ -18020,7 +18057,7 @@ enum ggml_opt_result ggml_opt_resume(
*gf = ggml_build_forward (f);
*gb = ggml_build_backward(ctx, gf, true);
return ggml_opt_resume_g(ctx, opt, f, gf, gb);
return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
}
enum ggml_opt_result ggml_opt_resume_g(
@@ -18028,7 +18065,9 @@ enum ggml_opt_result ggml_opt_resume_g(
struct ggml_opt_context * opt,
struct ggml_tensor * f,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb) {
struct ggml_cgraph * gb,
ggml_opt_callback callback,
void * callback_data) {
// build forward + backward compute graphs
enum ggml_opt_result result = GGML_OPT_OK;
@@ -18036,11 +18075,11 @@ enum ggml_opt_result ggml_opt_resume_g(
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
} break;
case GGML_OPT_LBFGS:
{
result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
} break;
}

ggml.h (15 changed lines)

@@ -1469,6 +1469,8 @@ extern "C" {
GGML_LINESEARCH_INVALID_PARAMETERS,
};
typedef void (*ggml_opt_callback)(void * data, float * sched);
// optimization parameters
//
// see ggml.c (ggml_opt_default_params) for default values
@@ -1538,6 +1540,9 @@ extern "C" {
bool just_initialized;
float loss_before;
float loss_after;
struct {
struct ggml_tensor * m; // first moment
struct ggml_tensor * v; // second moment
@@ -1577,10 +1582,10 @@ extern "C" {
// initialize optimizer context
GGML_API void ggml_opt_init(
struct ggml_context * ctx,
struct ggml_context * ctx,
struct ggml_opt_context * opt,
struct ggml_opt_params params,
int64_t nx);
struct ggml_opt_params params,
int64_t nx);
// continue optimizing the function defined by the tensor f
GGML_API enum ggml_opt_result ggml_opt_resume(
@@ -1594,7 +1599,9 @@ extern "C" {
struct ggml_opt_context * opt,
struct ggml_tensor * f,
struct ggml_cgraph * gf,
struct ggml_cgraph * gb);
struct ggml_cgraph * gb,
ggml_opt_callback callback,
void * callback_data);
//
// quantization
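
Putting it together, a call site that previously used ggml_opt_resume_g(ctx0, opt, loss, gf, gb) would now pass the callback and its data, roughly as sketched below. train_state and train_callback are the hypothetical names from the example above, params.warmup is the field used in the training example, and loss_before/loss_after are the fields this commit adds to ggml_opt_context.

    // wire the hypothetical callback into the optimizer
    struct train_state ts = { opt, params.warmup, 1.0f };
    ggml_opt_resume_g(ctx0, opt, loss, gf, gb, train_callback, &ts);
    // the loss values are now recorded by the optimizer itself
    printf("loss before/after opt: %.6f / %.6f\n", opt->loss_before, opt->loss_after);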