diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index bba1bdb7c..6f133ac5f 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -1730,6 +1730,7 @@ struct train_params {
     int n_ctx;
     int n_threads;
     int n_batch;
+    int n_gradient_accumulation;

     bool custom_n_ctx;

@@ -1804,6 +1805,7 @@ struct train_params get_default_train_params() {
     params.n_ctx = 128;
     params.n_threads = 6;
     params.n_batch = 8;
+    params.n_gradient_accumulation = 1;

     params.custom_n_ctx = false;

@@ -1880,6 +1882,7 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, "  -c N, --ctx N              Context size used during training (default %d)\n", params->n_ctx);
     fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
     fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
+    fprintf(stderr, "  --grad-acc N               Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
     fprintf(stderr, "  --norm-rms-eps F           RMS-Norm epsilon value (default %f)\n", params->f_norm_rms_eps);
     fprintf(stderr, "  --rope-freq-base F         Frequency base for ROPE (default %f)\n", params->rope_freq_base);
     fprintf(stderr, "  --rope-freq-scale F        Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
@@ -2015,6 +2018,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_batch = std::stoi(argv[i]);
+        } else if (arg == "--grad-acc") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_gradient_accumulation = std::stoi(argv[i]);
         } else if (arg == "--norm-rms-eps") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2299,83 +2308,85 @@ void print_duration(float fmillis) {
     printf("%02lld:%02lld:%02lld", hours, minutes, seconds);
 }

-void opt_callback(void * vdata, float * sched) {
+void opt_callback(void * vdata, int accum_step, float * sched) {
     struct opt_callback_data * data = (struct opt_callback_data *) vdata;
     struct train_params * params = data->params;
     struct ggml_opt_context * opt = data->opt;
     int n_batch = params->n_batch;
     int n_ctx = params->n_ctx;

-    int64_t now = ggml_time_ms();
-    if (now > data->last_time) {
-        float dt = now - data->last_time;
-        if (data->time_per_iter == 0) {
-            data->time_per_iter = dt;
-        } else {
-            const float gain = 0.7f;
-            data->time_per_iter = data->time_per_iter*(1.0f-gain) + dt*gain;
+    if (accum_step == 0) {
+        int64_t now = ggml_time_ms();
+        if (now > data->last_time) {
+            float dt = now - data->last_time;
+            if (data->time_per_iter == 0) {
+                data->time_per_iter = dt;
+            } else {
+                const float gain = 0.7f;
+                data->time_per_iter = data->time_per_iter*(1.0f-gain) + dt*gain;
+            }
         }
-    }
-    data->last_time = now;
-    float remaining_time = 0;
-    if (data->time_per_iter > 0) {
-        const int n_iter = params->use_adam ? params->adam_n_iter : params->lbfgs_n_iter;
-        const int done_iter = opt->iter - data->first_iter;
-        const int remaining_iter = n_iter - done_iter;
-        remaining_time = remaining_iter * data->time_per_iter;
-    }
-
-    const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
-    if (save_now) {
-        int new_iters = opt->iter - data->last_save_iter;
-        data->lora->train_its += new_iters;
-        data->lora->train_samples += new_iters * n_batch;
-        data->lora->train_tokens += new_iters * n_batch * n_ctx;
-
-        if (strlen(params->fn_checkpoint_out) > 0) {
-            save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, -1, params->fn_latest);
+        data->last_time = now;
+        float remaining_time = 0;
+        if (data->time_per_iter > 0) {
+            const int n_iter = params->use_adam ? params->adam_n_iter : params->lbfgs_n_iter;
+            const int done_iter = opt->iter - data->first_iter;
+            const int remaining_iter = n_iter - done_iter;
+            remaining_time = remaining_iter * data->time_per_iter;
         }
-        if (strlen(params->fn_lora_out) > 0) {
-            save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, -1, params->fn_latest);
+
+        const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
+        if (save_now) {
+            int new_iters = opt->iter - data->last_save_iter;
+            data->lora->train_its += new_iters;
+            data->lora->train_samples += new_iters * n_batch;
+            data->lora->train_tokens += new_iters * n_batch * n_ctx;
+
+            if (strlen(params->fn_checkpoint_out) > 0) {
+                save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
+                save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, -1, params->fn_latest);
+            }
+            if (strlen(params->fn_lora_out) > 0) {
+                save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest);
+                save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, -1, params->fn_latest);
+            }
+            data->last_save_iter = opt->iter;
         }
-        data->last_save_iter = opt->iter;
-    }

-    *sched = (opt->iter < params->warmup)
-        ? (float) opt->iter / (float) params->warmup
-        : cosine_decay_restart(
-            params->cos_decay_steps,
-            params->cos_decay_min,
-            opt->iter - params->warmup,
-            params->cos_decay_restart,
-            params->enable_restart);
-    float min_sched = params->adam_min_alpha / params->adam_alpha;
-    *sched = min_sched + *sched * (1.0f - min_sched);
+        *sched = (opt->iter < params->warmup)
+            ? (float) opt->iter / (float) params->warmup
+            : cosine_decay_restart(
+                params->cos_decay_steps,
+                params->cos_decay_min,
+                opt->iter - params->warmup,
+                params->cos_decay_restart,
+                params->enable_restart);
+        float min_sched = params->adam_min_alpha / params->adam_alpha;
+        *sched = min_sched + *sched * (1.0f - min_sched);

-    int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
-    if (impr_plot > 0) impr_plot = 0;
-    if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
-    printf("%s: iter=%*d sched=%f loss=%f",
-        __func__, 6, opt->iter, *sched, opt->loss_after);
-    if (data->time_per_iter > 0) {
-        printf(" dt=");
-        print_duration(data->time_per_iter);
-        printf(" eta=");
-        print_duration(remaining_time);
-    }
+        int impr_plot = -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
+        if (impr_plot > 0) impr_plot = 0;
+        if (std::isnan(opt->loss_before) || std::isnan(opt->loss_before)) impr_plot = 0;
+        printf("%s: iter=%*d sched=%f loss=%f",
+            __func__, 6, opt->iter, *sched, opt->loss_after);
+        if (data->time_per_iter > 0) {
+            printf(" dt=");
+            print_duration(data->time_per_iter);
+            printf(" eta=");
+            print_duration(remaining_time);
+        }

-    float improvement = opt->loss_before - opt->loss_after;
-    const float plot_scale = 10.0f;
-    int bar_len = (int)(1 + improvement*plot_scale + 0.5);
-    printf(" |");
-    for (int i=0; i<bar_len; ++i) {
-        printf("-");
+        float improvement = opt->loss_before - opt->loss_after;
+        const float plot_scale = 10.0f;
+        int bar_len = (int)(1 + improvement*plot_scale + 0.5);
+        printf(" |");
+        for (int i=0; i<bar_len; ++i) {
+            printf("-");
+        }
+        printf(">");
+        // printf("improvement: %*d>", impr_plot, (int)0);
+        printf("\n");
     }
-    printf(">");
-    // printf("improvement: %*d>", impr_plot, (int)0);
-    printf("\n");

     if (data->shuffle_countdown < n_batch) {
         printf("%s: reshuffle samples\n", __func__);
@@ -2491,30 +2502,32 @@ int main(int argc, char ** argv) {
     // set opt params from command line
     if (params.use_adam) {
         opt->params = ggml_opt_default_params(GGML_OPT_ADAM);
-        opt->params.print_forward_graph = false;
-        opt->params.print_backward_graph = false;
-        opt->params.n_threads = params.n_threads;
-        opt->params.past = params.opt_past;
-        opt->params.delta = params.opt_delta;
-        opt->params.max_no_improvement = params.opt_max_no_improvement;
-        opt->params.adam.n_iter = params.adam_n_iter;
-        opt->params.adam.sched = 1.0f;
-        opt->params.adam.alpha = params.adam_alpha;
-        opt->params.adam.decay = params.adam_decay;
-        opt->params.adam.decay_min_ndim = params.adam_decay_min_ndim;
-        opt->params.adam.beta1 = params.adam_beta1;
-        opt->params.adam.beta2 = params.adam_beta2;
-        opt->params.adam.gclip = params.adam_gclip;
-        opt->params.adam.eps_f = params.adam_eps_f;
+        opt->params.print_forward_graph     = false;
+        opt->params.print_backward_graph    = false;
+        opt->params.n_threads               = params.n_threads;
+        opt->params.past                    = params.opt_past;
+        opt->params.delta                   = params.opt_delta;
+        opt->params.max_no_improvement      = params.opt_max_no_improvement;
+        opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
+        opt->params.adam.n_iter             = params.adam_n_iter;
+        opt->params.adam.sched              = 1.0f;
+        opt->params.adam.alpha              = params.adam_alpha;
+        opt->params.adam.decay              = params.adam_decay;
+        opt->params.adam.decay_min_ndim     = params.adam_decay_min_ndim;
+        opt->params.adam.beta1              = params.adam_beta1;
+        opt->params.adam.beta2              = params.adam_beta2;
+        opt->params.adam.gclip              = params.adam_gclip;
+        opt->params.adam.eps_f              = params.adam_eps_f;
     } else {
         opt->params = ggml_opt_default_params(GGML_OPT_LBFGS);
-        opt->params.print_forward_graph = false;
-        opt->params.print_backward_graph = false;
-        opt->params.n_threads = params.n_threads;
-        opt->params.past = params.opt_past;
-        opt->params.delta = params.opt_delta;
-        opt->params.max_no_improvement = params.opt_max_no_improvement;
-        opt->params.lbfgs.n_iter = params.lbfgs_n_iter;
+        opt->params.print_forward_graph     = false;
+        opt->params.print_backward_graph    = false;
+        opt->params.n_threads               = params.n_threads;
+        opt->params.past                    = params.opt_past;
+        opt->params.delta                   = params.opt_delta;
+        opt->params.max_no_improvement      = params.opt_max_no_improvement;
+        opt->params.n_gradient_accumulation = params.n_gradient_accumulation;
+        opt->params.lbfgs.n_iter            = params.lbfgs_n_iter;
     }

     ggml_allocr * alloc = NULL;
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index f31427a99..21dacfeba 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1299,8 +1299,9 @@ struct train_params {

     int n_ff;
     int n_threads;
-    int n_batch;
     int n_examples;
+    int n_batch;
+    int n_gradient_accumulation;

     float f_norm_rms_eps;
     float rope_freq_base;
@@ -1362,8 +1363,9 @@ struct train_params get_default_train_params() {

     params.n_ff = 768;
     params.n_threads = 6;
-    params.n_batch = 8;
     params.n_examples = 1;
+    params.n_batch = 8;
+    params.n_gradient_accumulation = 1;

     params.f_norm_rms_eps = 1e-5f;
     params.rope_freq_base = 10000.0f;
@@ -1428,8 +1430,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, "  --rope-freq-base F         Frequency base for ROPE (default %f)\n", params->rope_freq_base);
     fprintf(stderr, "  --rope-freq-scale F        Frequency scale for ROPE (default %f)\n", params->rope_freq_scale);
     fprintf(stderr, "  -t N, --threads N          Number of threads (default %d)\n", params->n_threads);
-    fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
     fprintf(stderr, "  -n N, --examples N         Number of examples to train (default %d)\n", params->n_examples);
+    fprintf(stderr, "  -b N, --batch N            Parallel batch size (default %d)\n", params->n_batch);
+    fprintf(stderr, "  --grad-acc N               Number of gradient accumulation steps (simulates larger batch size of batch*gradacc) (default %d)\n", params->n_gradient_accumulation);
     fprintf(stderr, "  --print-info-interval N    Print infos during training each N examples (default %d)\n", params->print_info_interval);
     fprintf(stderr, "  --samples-after-nl         Training samples start after newlines. (default %s)\n", params->samples_start_after_nl ? "on" : "off");
     fprintf(stderr, "  --use-lbfgs                Use LBFGS optimizer instead of default Adam\n");
@@ -1591,6 +1594,12 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_batch = std::stoi(argv[i]);
+        } else if (arg == "--grad-acc") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params->n_gradient_accumulation = std::stoi(argv[i]);
         } else if (arg == "-n" || arg == "--examples") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1779,46 +1788,49 @@ struct opt_callback_data {
     struct ggml_tensor * target_probs;
 };

-void opt_callback(void * vdata, float * sched) {
+void opt_callback(void * vdata, int accum_step, float * sched) {
     struct opt_callback_data * data = (struct opt_callback_data *) vdata;
     struct train_params * params = data->params;
     struct ggml_opt_context * opt = data->opt;
     int n_batch = params->n_batch;
     int n_ctx = params->n_ctx;

-    const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
-    if (save_now) {
-        int new_iters = opt->iter - data->last_save_iter;
-        data->model->train_its += new_iters;
-        data->model->train_samples += new_iters * n_batch;
-        data->model->train_tokens += new_iters * n_batch * n_ctx;
+    if (accum_step == 0) {
+        const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
+        if (save_now) {
+            int new_iters = opt->iter - data->last_save_iter;
+            data->model->train_its += new_iters;
+            data->model->train_samples += new_iters * n_batch;
+            data->model->train_tokens += new_iters * n_batch * n_ctx;

-        if (strlen(params->fn_checkpoint_out) > 0) {
-            save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, -1, params->fn_latest);
+            if (strlen(params->fn_checkpoint_out) > 0) {
+                save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
+                save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, -1, params->fn_latest);
+            }
+            if (strlen(params->fn_model_out) > 0) {
+                save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, opt->iter, params->fn_latest);
+                save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, -1, params->fn_latest);
+            }
+            data->last_save_iter = opt->iter;
         }
-        if (strlen(params->fn_model_out) > 0) {
-            save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, opt->iter, params->fn_latest);
-            save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, -1, params->fn_latest);
-        }
-        data->last_save_iter = opt->iter;
+
+        *sched = (opt->iter < params->warmup)
+            ? (float) opt->iter / (float) params->warmup
+            : cosine_decay_restart(
+                params->cos_decay_steps,
+                params->cos_decay_min,
+                opt->iter - params->warmup,
+                params->cos_decay_restart,
+                params->enable_restart);
+        float min_sched = params->adam_min_alpha / params->adam_alpha;
+        *sched = min_sched + *sched * (1.0f - min_sched);
+
+        int impr_plot = std::isnan(opt->loss_after) ? 0 : -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
+        printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0);
     }

-    *sched = (opt->iter < params->warmup)
-        ? (float) opt->iter / (float) params->warmup
-        : cosine_decay_restart(
-            params->cos_decay_steps,
-            params->cos_decay_min,
-            opt->iter - params->warmup,
-            params->cos_decay_restart,
-            params->enable_restart);
-    float min_sched = params->adam_min_alpha / params->adam_alpha;
-    *sched = min_sched + *sched * (1.0f - min_sched);
-
-    int impr_plot = std::isnan(opt->loss_after) ? 0 : -(int)(1 + (opt->loss_before - opt->loss_after) * 10.0f + 0.5f);
-    printf("%s: iter=%*d, sched=%f loss0=%f loss=%f | improvement: %*d>\n", __func__, 6, opt->iter, *sched, opt->loss_before, opt->loss_after, impr_plot, (int)0);
-
     if (data->shuffle_countdown < n_batch) {
         printf("%s: reshuffle samples\n", __func__);
         shuffle_ints(data->samples_data, data->samples_data + data->samples_size);
@@ -1917,29 +1929,31 @@ int main(int argc, char ** argv) {

     struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
     struct ggml_opt_params opt_params_lbfgs = ggml_opt_default_params(GGML_OPT_LBFGS);
-    opt_params_adam.print_forward_graph = false;
-    opt_params_adam.print_backward_graph = false;
-    opt_params_adam.n_threads = params.n_threads;
-    opt_params_adam.past = params.opt_past;
-    opt_params_adam.delta = params.opt_delta;
-    opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
-    opt_params_adam.adam.n_iter = params.adam_n_iter;
-    opt_params_adam.adam.sched = 1.0f;
-    opt_params_adam.adam.alpha = params.adam_alpha;
-    opt_params_adam.adam.decay = params.adam_decay;
-    opt_params_adam.adam.decay_min_ndim = params.adam_decay_min_ndim;
-    opt_params_adam.adam.beta1 = params.adam_beta1;
-    opt_params_adam.adam.beta2 = params.adam_beta2;
-    opt_params_adam.adam.gclip = params.adam_gclip;
-    opt_params_adam.adam.eps_f = params.adam_eps_f;
+    opt_params_adam.print_forward_graph      = false;
+    opt_params_adam.print_backward_graph     = false;
+    opt_params_adam.n_threads                = params.n_threads;
+    opt_params_adam.past                     = params.opt_past;
+    opt_params_adam.delta                    = params.opt_delta;
+    opt_params_adam.max_no_improvement       = params.opt_max_no_improvement;
+    opt_params_adam.n_gradient_accumulation  = params.n_gradient_accumulation;
+    opt_params_adam.adam.n_iter              = params.adam_n_iter;
+    opt_params_adam.adam.sched               = 1.0f;
+    opt_params_adam.adam.alpha               = params.adam_alpha;
+    opt_params_adam.adam.decay               = params.adam_decay;
+    opt_params_adam.adam.decay_min_ndim      = params.adam_decay_min_ndim;
+    opt_params_adam.adam.beta1               = params.adam_beta1;
+    opt_params_adam.adam.beta2               = params.adam_beta2;
+    opt_params_adam.adam.gclip               = params.adam_gclip;
+    opt_params_adam.adam.eps_f               = params.adam_eps_f;

-    opt_params_lbfgs.print_forward_graph = false;
-    opt_params_lbfgs.print_backward_graph = false;
-    opt_params_lbfgs.n_threads = params.n_threads;
-    opt_params_adam.past = params.opt_past;
-    opt_params_adam.delta = params.opt_delta;
-    opt_params_adam.max_no_improvement = params.opt_max_no_improvement;
-    opt_params_lbfgs.lbfgs.n_iter = params.lbfgs_n_iter;
+    opt_params_lbfgs.print_forward_graph     = false;
+    opt_params_lbfgs.print_backward_graph    = false;
+    opt_params_lbfgs.n_threads               = params.n_threads;
+    opt_params_lbfgs.past                    = params.opt_past;
+    opt_params_lbfgs.delta                   = params.opt_delta;
+    opt_params_lbfgs.max_no_improvement      = params.opt_max_no_improvement;
+    opt_params_lbfgs.n_gradient_accumulation = params.n_gradient_accumulation;
+    opt_params_lbfgs.lbfgs.n_iter            = params.lbfgs_n_iter;

     opt->ctx = model.ctx;
     opt->params = params.use_adam ? opt_params_adam : opt_params_lbfgs;
diff --git a/ggml.c b/ggml.c
index f5419c34e..d7dc3cb44 100644
--- a/ggml.c
+++ b/ggml.c
@@ -19112,7 +19112,7 @@ static void ggml_opt_get_params(int np, struct ggml_tensor * const ps[], float *
 }

 static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g) {
-    int i = 0;
+    int64_t i = 0;
     for (int p = 0; p < np; ++p) {
         const int64_t ne = ggml_nelements(ps[p]) ;
         // TODO: add function to get all elements at once
@@ -19122,6 +19122,17 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g
     }
 }

+static void ggml_opt_acc_grad(int np, struct ggml_tensor * const ps[], float * g, float scale) {
+    int64_t i = 0;
+    for (int p = 0; p < np; ++p) {
+        const int64_t ne = ggml_nelements(ps[p]) ;
+        // TODO: add function to get all elements at once
+        for (int64_t j = 0; j < ne; ++j) {
+            g[i++] += ggml_get_f32_1d(ps[p]->grad, j) * scale;
+        }
+    }
+}
+
 //
 // ADAM
 //
@@ -19170,26 +19181,37 @@ static enum ggml_opt_result ggml_opt_adam(
     const float eps = params.adam.eps;
     const float gclip = params.adam.gclip;
     const int decay_min_ndim = params.adam.decay_min_ndim;
+    const int n_accum = MAX(1, params.n_gradient_accumulation);
+    const float accum_norm = 1.0f / (float) n_accum;

+    float * g = opt->adam.g->data;  // gradients
     float * m = opt->adam.m->data;  // first moment
     float * v = opt->adam.v->data;  // second moment

     float * pf = params.past > 0 ? opt->adam.pf->data : NULL; // past function values

-    if (callback) {
-        callback(callback_data, &sched);
-    }
-
-    // compute the function value
-    // ggml_graph_reset  (gf);
-    ggml_set_f32      (f->grad, 1.0f);
-
     struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
     struct ggml_object * obj = ggml_new_object(ctx, GGML_OBJECT_WORK_BUFFER, cplan.work_size);
     cplan.work_data = (uint8_t *)ctx->mem_buffer + obj->offs;
-    ggml_graph_compute(gb, &cplan);
-
-    opt->adam.fx_prev = ggml_get_f32_1d(f, 0);
+
+    // compute the function value
+
+    float fx = 0;
+    ggml_set_zero(opt->adam.g);
+    for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+        if (callback) {
+            callback(callback_data, accum_step, &sched);
+        }
+        // ggml_graph_reset  (gf);
+        ggml_set_f32      (f->grad, 1.0f);
+        ggml_graph_compute(gb, &cplan);
+        ggml_opt_acc_grad(np, ps, g, accum_norm);
+        fx += ggml_get_f32_1d(f, 0);
+    }
+    fx *= accum_norm;
+
+    opt->adam.fx_prev = fx;
     opt->adam.fx_best = opt->adam.fx_prev;
     if (pf) {
         pf[opt->iter % params.past] = opt->adam.fx_prev;
@@ -19234,12 +19256,8 @@ static enum ggml_opt_result ggml_opt_adam(
         if (gclip > 0.0f) {
             // gradient clipping
             ggml_float sum = 0.0;
-            for (int p = 0; p < np; ++p) {
-                const int64_t ne = ggml_nelements(ps[p]);
-                for (int64_t j = 0; j < ne; ++j) {
-                    float g = ggml_get_f32_1d(ps[p]->grad, j);
-                    sum += (ggml_float)(g*g);
-                }
+            for (int64_t i = 0; i < nx; ++i) {
+                sum += (ggml_float)(g[i]*g[i]);
             }
             ggml_float norm = sqrt(sum);
             if (norm > (ggml_float) gclip) {
@@ -19253,10 +19271,10 @@ static enum ggml_opt_result ggml_opt_adam(
             const int64_t ne = ggml_nelements(ps[p]);
             const float p_decay = ((ps[p]->n_dims >= decay_min_ndim) ? decay : 0.0f) * sched;
             for (int64_t j = 0; j < ne; ++j) {
-                float x = ggml_get_f32_1d(ps[p], j);
-                float g = ggml_get_f32_1d(ps[p]->grad, j)*gnorm;
-                m[i] = m[i]*beta1 + g*(1.0f - beta1);
-                v[i] = v[i]*beta2 + g*g*(1.0f - beta2);
+                float x  = ggml_get_f32_1d(ps[p], j);
+                float g_ = g[i]*gnorm;
+                m[i] = m[i]*beta1 + g_*(1.0f - beta1);
+                v[i] = v[i]*beta2 + g_*g_*(1.0f - beta2);
                 float mh = m[i]*beta1h;
                 float vh = v[i]*beta2h;
                 vh = sqrtf(vh) + eps;
@@ -19267,16 +19285,20 @@ static enum ggml_opt_result ggml_opt_adam(
             }
         }

-        if (callback) {
-            callback(callback_data, &sched);
+        fx = 0;
+        ggml_set_zero(opt->adam.g);
+        for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+            if (callback) {
+                callback(callback_data, accum_step, &sched);
+            }
+            // ggml_graph_reset  (gf);
+            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_compute(gb, &cplan);
+            ggml_opt_acc_grad(np, ps, g, accum_norm);
+            fx += ggml_get_f32_1d(f, 0);
         }
+        fx *= accum_norm;

-        // ggml_graph_reset  (gf);
-        ggml_set_f32      (f->grad, 1.0f);
-
-        ggml_graph_compute(gb, &cplan);
-
-        const float fx = ggml_get_f32_1d(f, 0);
         opt->loss_after = fx;

@@ -19373,6 +19395,9 @@ static enum ggml_opt_result linesearch_backtracking(
     const float dec = 0.5f;
     const float inc = 2.1f;

+    const int n_accum = MAX(1, params->n_gradient_accumulation);
+    const float accum_norm = 1.0f / (float) n_accum;
+
     if (*step <= 0.f) {
         return GGML_LINESEARCH_INVALID_PARAMETERS;
     }
@@ -19390,12 +19415,6 @@ static enum ggml_opt_result linesearch_backtracking(
     dgtest = params->lbfgs.ftol*dginit;

     while (true) {
-        if (callback) {
-            // LBFG-S does not support learning rate -> ignore learning schedule
-            float sched = 0;
-            callback(callback_data, &sched);
-        }
-
         ggml_vec_cpy_f32(nx, x, xp);
         ggml_vec_mad_f32(nx, x, d, *step);

@@ -19403,14 +19422,22 @@ static enum ggml_opt_result linesearch_backtracking(
         {
             ggml_opt_set_params(np, ps, x);

-            //ggml_graph_reset  (gf);
-            ggml_set_f32      (f->grad, 1.0f);
+            *fx = 0;
+            memset(g, 0, sizeof(float)*nx);
+            for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+                if (callback) {
+                    // LBFG-S does not support learning rate -> ignore learning schedule
+                    float sched = 0;
+                    callback(callback_data, accum_step, &sched);
+                }
+                // ggml_graph_reset  (gf);
+                ggml_set_f32      (f->grad, 1.0f);
+                ggml_graph_compute(gb, cplan);
+                ggml_opt_acc_grad(np, ps, g, accum_norm);
+                *fx += ggml_get_f32_1d(f, 0);
+            }
+            *fx *= accum_norm;

-            ggml_graph_compute(gb, cplan);
-
-            ggml_opt_get_grad(np, ps, g);
-
-            *fx = ggml_get_f32_1d(f, 0);
         }

         ++count;
@@ -19512,6 +19539,9 @@ static enum ggml_opt_result ggml_opt_lbfgs(

     float * pf = params.past > 0 ? opt->lbfgs.pf->data : NULL; // past function values

+    const int n_accum = MAX(1, params.n_gradient_accumulation);
+    const float accum_norm = 1.0f / (float) n_accum;
+
     float fx = 0.0f; // cost function value
     float xnorm = 0.0f; // ||x||
     float gnorm = 0.0f; // ||g||
@@ -19525,24 +19555,25 @@ static enum ggml_opt_result ggml_opt_lbfgs(
     float * lm_s = opt->lbfgs.lms->data;
     float * lm_y = opt->lbfgs.lmy->data;

-    if (callback) {
-        // LBFG-S does not support learning rate -> ignore learning schedule
-        float sched = 0;
-        callback(callback_data, &sched);
-    }
-
     // evaluate the function value and its gradient
     {
         ggml_opt_set_params(np, ps, x);

-        //ggml_graph_reset  (gf);
-        ggml_set_f32      (f->grad, 1.0f);
-
-        ggml_graph_compute(gb, &cplan);
-
-        ggml_opt_get_grad(np, ps, g);
-
-        fx = ggml_get_f32_1d(f, 0);
+        fx = 0;
+        memset(g, 0, sizeof(float)*nx);
+        for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
+            if (callback) {
+                // LBFG-S does not support learning rate -> ignore learning schedule
+                float sched = 0;
+                callback(callback_data, accum_step, &sched);
+            }
+            // ggml_graph_reset  (gf);
+            ggml_set_f32      (f->grad, 1.0f);
+            ggml_graph_compute(gb, &cplan);
+            ggml_opt_acc_grad(np, ps, g, accum_norm);
+            fx += ggml_get_f32_1d(f, 0);
+        }
+        fx *= accum_norm;

         opt->loss_before = fx;
         opt->loss_after = fx;
@@ -19729,6 +19760,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                 .print_forward_graph = true,
                 .print_backward_graph = true,

+                .n_gradient_accumulation = 1,
+
                 .adam = {
                     .n_iter = 10000,
                     .sched  = 1.000f,
@@ -19757,6 +19790,8 @@ struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type) {
                 .print_forward_graph = true,
                 .print_backward_graph = true,

+                .n_gradient_accumulation = 1,
+
                 .lbfgs = {
                     .m      = 6,
                     .n_iter = 100,
@@ -19790,7 +19825,7 @@ GGML_API void ggml_opt_init(
     if (opt->ctx == NULL) {
        struct ggml_init_params ctx_opt_params;
        if (opt->params.type == GGML_OPT_ADAM) {
-            ctx_opt_params.mem_size = GGML_MEM_ALIGN*2 + ggml_tensor_overhead()*2 + ggml_type_size(GGML_TYPE_F32)*nx*2;
+            ctx_opt_params.mem_size = GGML_MEM_ALIGN*3 + ggml_tensor_overhead()*3 + ggml_type_size(GGML_TYPE_F32)*nx*3;
             if (opt->params.past > 0) {
                 ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
             }
@@ -19808,6 +19843,7 @@ GGML_API void ggml_opt_init(
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
+                opt->adam.g  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
                 opt->adam.m  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
                 opt->adam.v  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
                 opt->adam.pf = params.past > 0
diff --git a/ggml.h b/ggml.h
index 0ab8cedea..0d38a7110 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1708,7 +1708,7 @@ extern "C" {
         GGML_LINESEARCH_INVALID_PARAMETERS,
     };

-    typedef void (*ggml_opt_callback)(void * data, float * sched);
+    typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched);

     // optimization parameters
     //
@@ -1739,6 +1739,8 @@ extern "C" {
        bool print_forward_graph;
        bool print_backward_graph;

+        int n_gradient_accumulation;
+
        // ADAM parameters
        struct {
            int n_iter;
@@ -1784,6 +1786,7 @@ extern "C" {
        float loss_after;

        struct {
+            struct ggml_tensor * g;  // current gradient
            struct ggml_tensor * m;  // first moment
            struct ggml_tensor * v;  // second moment
            struct ggml_tensor * pf; // past function values
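
The core of the change in ggml_opt_adam and ggml_opt_lbfgs above is to evaluate the graph n_gradient_accumulation times per optimizer iteration, summing gradients scaled by 1/n_accum and averaging the loss before a single parameter update. Below is a minimal standalone sketch of that pattern (not part of the patch) using a toy quadratic loss; toy_loss and toy_grad are hypothetical stand-ins for ggml_graph_compute plus ggml_opt_acc_grad, and in the real examples the per-step callback is where a new micro-batch would be loaded.

#include <stdio.h>

// Toy stand-ins for one forward/backward pass over a micro-batch.
static float toy_loss(float x) { return x * x; }     // f(x)  = x^2
static float toy_grad(float x) { return 2.0f * x; }  // df/dx = 2x

int main(void) {
    const int   n_accum    = 4;                  // cf. params.n_gradient_accumulation
    const float accum_norm = 1.0f / (float) n_accum;

    float x  = 3.0f;   // the single "parameter" being optimized
    float lr = 0.1f;   // learning rate

    for (int iter = 0; iter < 5; ++iter) {
        float fx = 0.0f; // accumulated loss
        float g  = 0.0f; // accumulated gradient, cf. ggml_set_zero(opt->adam.g)

        for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
            // in the real code, callback(callback_data, accum_step, &sched)
            // runs here and the graph is recomputed for this step
            g  += toy_grad(x) * accum_norm;      // cf. ggml_opt_acc_grad(np, ps, g, accum_norm)
            fx += toy_loss(x);
        }
        fx *= accum_norm;                        // average the loss, as in the patch

        x -= lr * g;                             // one (plain SGD) update per iteration
        printf("iter=%d loss=%f x=%f\n", iter, fx, x);
    }
    return 0;
}

With identical data at every accumulation step, as in this toy, the averaged gradient equals a single-step gradient; the benefit appears when each accum_step sees a different batch, which is what the new accum_step argument to ggml_opt_callback makes possible.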