add optimization callback to ggml_opt_resume_g
this callback is called before each iteration with custom data and pointer to learning schedule parameter (only used in Adam(W)). can be used for dynamic learning schedule and setting input data for batches before each iteration
This commit is contained in:
parent
e843d6e71c
commit
bfc3119139
3 changed files with 69 additions and 31 deletions
|
@ -4046,12 +4046,8 @@ int main(int argc, char ** argv) {
|
|||
ggml_build_backward_expand(ctx0, gf, gb, true);
|
||||
}
|
||||
|
||||
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
|
||||
|
||||
size_t used_mem_before_opt = ggml_used_mem(ctx0);
|
||||
|
||||
float error_before_opt = ggml_get_f32_1d(loss, 0);
|
||||
|
||||
opt->params.adam.sched = (opt->iter < params.warmup)
|
||||
? (float) opt->iter / (float) params.warmup
|
||||
: cosine_decay_restart(
|
||||
|
@ -4066,7 +4062,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
|
||||
|
||||
ggml_opt_resume_g(ctx0, opt, loss, gf, gb);
|
||||
ggml_opt_resume_g(ctx0, opt, loss, gf, gb, NULL, NULL);
|
||||
|
||||
size_t used_mem_after_opt = ggml_used_mem(ctx0);
|
||||
|
||||
|
@ -4074,14 +4070,10 @@ int main(int argc, char ** argv) {
|
|||
model.train_samples += n_batch;
|
||||
model.train_tokens += n_batch * n_tokens;
|
||||
|
||||
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
|
||||
|
||||
float error_after_opt = ggml_get_f32_1d(loss, 0);
|
||||
|
||||
if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
|
||||
printf("Example %d, opt iter %d\n", ex, opt->iter);
|
||||
printf("error_before_opt: %.6f\n", error_before_opt);
|
||||
printf("error_after_opt: %.6f\n", error_after_opt);
|
||||
printf("error_before_opt: %.6f\n", opt->loss_before);
|
||||
printf("error_after_opt: %.6f\n", opt->loss_after);
|
||||
printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
|
||||
printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
|
||||
}
|
||||
|
|
71
ggml.c
71
ggml.c
|
@ -17281,7 +17281,9 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
struct ggml_opt_params params,
|
||||
struct ggml_tensor * f,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb) {
|
||||
struct ggml_cgraph * gb,
|
||||
ggml_opt_callback callback,
|
||||
void * callback_data) {
|
||||
GGML_ASSERT(ggml_is_scalar(f));
|
||||
|
||||
// these will store the parameters we want to optimize
|
||||
|
@ -17307,8 +17309,8 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
}
|
||||
|
||||
// constants
|
||||
const float sched = params.adam.sched;
|
||||
const float alpha = params.adam.alpha * sched;
|
||||
float sched = params.adam.sched;
|
||||
const float alpha = params.adam.alpha;
|
||||
const float decay = params.adam.decay * alpha;
|
||||
const float beta1 = params.adam.beta1;
|
||||
const float beta2 = params.adam.beta2;
|
||||
|
@ -17320,6 +17322,10 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
|
||||
float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values
|
||||
|
||||
if (callback) {
|
||||
callback(callback_data, &sched);
|
||||
}
|
||||
|
||||
// compute the function value
|
||||
ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
@ -17332,6 +17338,9 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
pf[opt->iter % params.past] = opt->adam.fx_prev;
|
||||
}
|
||||
|
||||
opt->loss_before = opt->adam.fx_prev;
|
||||
opt->loss_after = opt->adam.fx_prev;
|
||||
|
||||
// initialize
|
||||
if (opt->just_initialized) {
|
||||
opt->adam.n_no_improvement = 0;
|
||||
|
@ -17380,11 +17389,12 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
gnorm = (float) ((ggml_float) gclip / norm);
|
||||
}
|
||||
}
|
||||
const float beta1h = alpha/(1.0f - powf(beta1, opt->iter));
|
||||
const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter));
|
||||
const float beta1h = alpha*sched/(1.0f - powf(beta1, opt->iter));
|
||||
const float beta2h = 1.0f/(1.0f - powf(beta2, opt->iter));
|
||||
int64_t i = 0;
|
||||
for (int p = 0; p < np; ++p) {
|
||||
const int64_t ne = ggml_nelements(ps[p]);
|
||||
const float p_decay = decay * sched;
|
||||
for (int64_t j = 0; j < ne; ++j) {
|
||||
float x = ggml_get_f32_1d(ps[p], j);
|
||||
float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
|
||||
|
@ -17393,13 +17403,13 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
float mh = m[i]*beta1h;
|
||||
float vh = v[i]*beta2h;
|
||||
vh = sqrtf(vh) + eps;
|
||||
x = x*(1.0f - decay) - mh/vh;
|
||||
x = x*(1.0f - p_decay) - mh/vh;
|
||||
ggml_set_f32_1d(ps[p], j, x);
|
||||
++i;
|
||||
}
|
||||
}
|
||||
}
|
||||
// {
|
||||
{
|
||||
// // update the gradient
|
||||
// ggml_opt_get_grad(np, ps, g1);
|
||||
|
||||
|
@ -17436,7 +17446,11 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
|
||||
// // update the parameters
|
||||
// ggml_opt_set_params(np, ps, x);
|
||||
// }
|
||||
}
|
||||
|
||||
if (callback) {
|
||||
callback(callback_data, &sched);
|
||||
}
|
||||
|
||||
ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
@ -17444,6 +17458,8 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
ggml_graph_compute_with_ctx(ctx, gb, params.n_threads);
|
||||
|
||||
const float fx = ggml_get_f32_1d(f, 0);
|
||||
opt->loss_after = fx;
|
||||
|
||||
|
||||
// check convergence
|
||||
if (fabsf(fx - fx_prev[0])/fx < params.adam.eps_f) {
|
||||
|
@ -17525,7 +17541,9 @@ static enum ggml_opt_result linesearch_backtracking(
|
|||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb,
|
||||
const int np,
|
||||
struct ggml_tensor * ps[]) {
|
||||
struct ggml_tensor * ps[],
|
||||
ggml_opt_callback callback,
|
||||
void * callback_data) {
|
||||
int count = 0;
|
||||
|
||||
float width = 0.0f;
|
||||
|
@ -17554,6 +17572,12 @@ static enum ggml_opt_result linesearch_backtracking(
|
|||
dgtest = params->lbfgs.ftol*dginit;
|
||||
|
||||
while (true) {
|
||||
if (callback) {
|
||||
// LBFG-S does not support learning rate -> ignore learning schedule
|
||||
float sched = 0;
|
||||
callback(callback_data, &sched);
|
||||
}
|
||||
|
||||
ggml_vec_cpy_f32(nx, x, xp);
|
||||
ggml_vec_mad_f32(nx, x, d, *step);
|
||||
|
||||
|
@ -17624,7 +17648,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
struct ggml_opt_params params,
|
||||
struct ggml_tensor * f,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb) {
|
||||
struct ggml_cgraph * gb,
|
||||
ggml_opt_callback callback,
|
||||
void * callback_data) {
|
||||
if (params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_WOLFE ||
|
||||
params.lbfgs.linesearch == GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE) {
|
||||
if (params.lbfgs.wolfe <= params.lbfgs.ftol || 1.f <= params.lbfgs.wolfe) {
|
||||
|
@ -17677,6 +17703,12 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
float * lm_s = opt->lbfgs.lms->data;
|
||||
float * lm_y = opt->lbfgs.lmy->data;
|
||||
|
||||
if (callback) {
|
||||
// LBFG-S does not support learning rate -> ignore learning schedule
|
||||
float sched = 0;
|
||||
callback(callback_data, &sched);
|
||||
}
|
||||
|
||||
// evaluate the function value and its gradient
|
||||
{
|
||||
ggml_opt_set_params(np, ps, x);
|
||||
|
@ -17689,6 +17721,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
ggml_opt_get_grad(np, ps, g);
|
||||
|
||||
fx = ggml_get_f32_1d(f, 0);
|
||||
|
||||
opt->loss_before = fx;
|
||||
opt->loss_after = fx;
|
||||
}
|
||||
|
||||
// search direction = -gradient
|
||||
|
@ -17743,7 +17778,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
ggml_vec_cpy_f32(nx, xp, x);
|
||||
ggml_vec_cpy_f32(nx, gp, g);
|
||||
|
||||
ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps);
|
||||
ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps, callback, callback_data);
|
||||
|
||||
if (ls < 0) {
|
||||
// linesearch failed - go back to the previous point and return
|
||||
|
@ -17753,6 +17788,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
return ls;
|
||||
}
|
||||
|
||||
opt->loss_after = fx;
|
||||
|
||||
ggml_vec_norm_f32(nx, &xnorm, x);
|
||||
ggml_vec_norm_f32(nx, &gnorm, g);
|
||||
|
||||
|
@ -17810,7 +17847,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
// ys = y^t \cdot s -> 1 / \rho.
|
||||
// yy = y^t \cdot y.
|
||||
//
|
||||
ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0] *nx]);
|
||||
ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
|
||||
ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
|
||||
|
||||
lm_ys[end[0]] = ys;
|
||||
|
@ -18020,7 +18057,7 @@ enum ggml_opt_result ggml_opt_resume(
|
|||
*gf = ggml_build_forward (f);
|
||||
*gb = ggml_build_backward(ctx, gf, true);
|
||||
|
||||
return ggml_opt_resume_g(ctx, opt, f, gf, gb);
|
||||
return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
|
||||
}
|
||||
|
||||
enum ggml_opt_result ggml_opt_resume_g(
|
||||
|
@ -18028,7 +18065,9 @@ enum ggml_opt_result ggml_opt_resume_g(
|
|||
struct ggml_opt_context * opt,
|
||||
struct ggml_tensor * f,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb) {
|
||||
struct ggml_cgraph * gb,
|
||||
ggml_opt_callback callback,
|
||||
void * callback_data) {
|
||||
|
||||
// build forward + backward compute graphs
|
||||
enum ggml_opt_result result = GGML_OPT_OK;
|
||||
|
@ -18036,11 +18075,11 @@ enum ggml_opt_result ggml_opt_resume_g(
|
|||
switch (opt->params.type) {
|
||||
case GGML_OPT_ADAM:
|
||||
{
|
||||
result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb);
|
||||
result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
|
||||
} break;
|
||||
case GGML_OPT_LBFGS:
|
||||
{
|
||||
result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb);
|
||||
result = ggml_opt_lbfgs(ctx, opt, opt->params, f, gf, gb, callback, callback_data);
|
||||
} break;
|
||||
}
|
||||
|
||||
|
|
15
ggml.h
15
ggml.h
|
@ -1469,6 +1469,8 @@ extern "C" {
|
|||
GGML_LINESEARCH_INVALID_PARAMETERS,
|
||||
};
|
||||
|
||||
typedef void (*ggml_opt_callback)(void * data, float * sched);
|
||||
|
||||
// optimization parameters
|
||||
//
|
||||
// see ggml.c (ggml_opt_default_params) for default values
|
||||
|
@ -1538,6 +1540,9 @@ extern "C" {
|
|||
|
||||
bool just_initialized;
|
||||
|
||||
float loss_before;
|
||||
float loss_after;
|
||||
|
||||
struct {
|
||||
struct ggml_tensor * m; // first moment
|
||||
struct ggml_tensor * v; // second moment
|
||||
|
@ -1577,10 +1582,10 @@ extern "C" {
|
|||
|
||||
// initialize optimizer context
|
||||
GGML_API void ggml_opt_init(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_opt_context * opt,
|
||||
struct ggml_opt_params params,
|
||||
int64_t nx);
|
||||
struct ggml_opt_params params,
|
||||
int64_t nx);
|
||||
|
||||
// continue optimizing the function defined by the tensor f
|
||||
GGML_API enum ggml_opt_result ggml_opt_resume(
|
||||
|
@ -1594,7 +1599,9 @@ extern "C" {
|
|||
struct ggml_opt_context * opt,
|
||||
struct ggml_tensor * f,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb);
|
||||
struct ggml_cgraph * gb,
|
||||
ggml_opt_callback callback,
|
||||
void * callback_data);
|
||||
|
||||
//
|
||||
// quantization
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue