From 77a3092c830d58a9dfac26aa0445dac57f7b4f3e Mon Sep 17 00:00:00 2001
From: xaedes
Date: Wed, 23 Aug 2023 19:34:45 +0200
Subject: [PATCH] update checkpoint train stats before saving via "--save-every"

---
 examples/finetune/finetune.cpp | 21 +++++++++++++++------
 1 file changed, 15 insertions(+), 6 deletions(-)

diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index eaceb71df..f4beb59e2 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -2676,17 +2676,23 @@ void opt_callback(void * vdata, float * sched) {
     struct train_params * params = data->params;
     struct ggml_opt_context * opt = data->opt;
     int n_batch = params->n_batch;
+    int n_ctx = params->n_ctx;
 
     const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
     if (save_now) {
+        int new_iters = opt->iter - data->last_save_iter;
+        data->lora->train_its += new_iters;
+        data->lora->train_samples += new_iters * n_batch;
+        data->lora->train_tokens += new_iters * n_batch * n_ctx;
+
         if (strlen(params->fn_checkpoint_out) > 0) {
             save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, opt->iter, params->fn_latest);
             save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, -1, params->fn_latest);
-        } 
+        }
         if (strlen(params->fn_lora_out) > 0) {
             save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest);
             save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, -1, params->fn_latest);
-        } 
+        }
         data->last_save_iter = opt->iter;
     }
 
@@ -3004,10 +3010,6 @@ int main(int argc, char ** argv) {
 
         size_t used_mem_after_opt = ggml_used_mem(ctx0);
 
-        int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter;
-        lora.train_its = opt->iter;
-        lora.train_samples += n_batch * n_iter;
-        lora.train_tokens += n_batch * n_tokens * n_iter;
 
         if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
             printf("Example %d, opt iter %d\n", ex, opt->iter);
@@ -3050,6 +3052,11 @@ int main(int argc, char ** argv) {
     double dd = (double) d * 1e-3;
     printf("%s: total training time=%f seconds\n", __func__, dd);
 
+    int new_iters = opt->iter - opt_cb_data.last_save_iter;
+    lora.train_its += new_iters;
+    lora.train_samples += new_iters * n_batch;
+    lora.train_tokens += new_iters * n_batch * n_tokens;
+
    if (params.n_examples > 0) {
         save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, opt->iter, params.fn_latest);
         save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, -1, params.fn_latest);
@@ -3060,6 +3067,8 @@ int main(int argc, char ** argv) {
         save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest);
     }
 
+    opt_cb_data.last_save_iter = opt->iter;
+
     {
         int n_gen = params.n_predict;
         int sample_ctx = n_tokens - n_tokens/8;
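
Note (not part of the patch): a minimal standalone sketch of the delta-based accounting this change moves into the save path. All names here (stats, update_stats) are hypothetical, not from finetune.cpp. The idea is that train_its/train_samples/train_tokens advance only by the iterations elapsed since the last save, so a periodic "--save-every" checkpoint and the final save at the end of main() each count every iteration exactly once:

    // sketch: delta-based stats bookkeeping, assuming hypothetical helpers
    #include <cstdio>

    struct stats { int train_its = 0, train_samples = 0, train_tokens = 0; };

    // advance counters by the iterations elapsed since the last save,
    // then record the current iteration as accounted for
    void update_stats(stats & s, int iter, int & last_save_iter, int n_batch, int n_ctx) {
        int new_iters = iter - last_save_iter;       // iterations not yet counted
        s.train_its     += new_iters;
        s.train_samples += new_iters * n_batch;      // one batch of samples per iteration
        s.train_tokens  += new_iters * n_batch * n_ctx;
        last_save_iter = iter;                       // mirrors data->last_save_iter = opt->iter
    }

    int main() {
        stats s;
        int last_save_iter = 0;
        update_stats(s, 10, last_save_iter, 8, 512); // periodic save at iter 10
        update_stats(s, 10, last_save_iter, 8, 512); // final save at the same iter: adds nothing
        printf("its=%d samples=%d tokens=%d\n", s.train_its, s.train_samples, s.train_tokens);
        // prints: its=10 samples=80 tokens=40960
    }

This also suggests why the patch deletes the per-example block in the training loop: updating the counters there with n_iter per example would double-count iterations once the save callback updates the same fields.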