diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index 6f133ac5f..d205367b3 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -2339,8 +2339,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
     if (save_now) {
         int new_iters = opt->iter - data->last_save_iter;
         data->lora->train_its += new_iters;
-        data->lora->train_samples += new_iters * n_batch;
-        data->lora->train_tokens += new_iters * n_batch * n_ctx;
+        data->lora->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
+        data->lora->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
 
         if (strlen(params->fn_checkpoint_out) > 0) {
             save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
@@ -2779,8 +2779,8 @@ int main(int argc, char ** argv) {
     int new_iters = opt->iter - opt_cb_data.last_save_iter;
     if (new_iters > 0) {
         lora.train_its += new_iters;
-        lora.train_samples += new_iters * n_batch;
-        lora.train_tokens += new_iters * n_batch * n_tokens;
+        lora.train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
+        lora.train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
 
         if (strlen(params.fn_checkpoint_out) > 0) {
             save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, opt, params.pattern_fn_it, opt->iter, params.fn_latest);
diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 21dacfeba..549302a81 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1800,8 +1800,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
     if (save_now) {
         int new_iters = opt->iter - data->last_save_iter;
         data->model->train_its += new_iters;
-        data->model->train_samples += new_iters * n_batch;
-        data->model->train_tokens += new_iters * n_batch * n_ctx;
+        data->model->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
+        data->model->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
 
         if (strlen(params->fn_checkpoint_out) > 0) {
             save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
@@ -2122,8 +2122,8 @@ int main(int argc, char ** argv) {
 
     int new_iters = opt->iter - opt_cb_data.last_save_iter;
     model.train_its += new_iters;
-    model.train_samples += new_iters * n_batch;
-    model.train_tokens += new_iters * n_batch * n_tokens;
+    model.train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
+    model.train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
 
     if (params.n_examples > 0) {
         save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt, params.pattern_fn_it, opt->iter, params.fn_latest);
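
For context, a minimal standalone sketch of the arithmetic behind this fix (not part of the patch; all values below are hypothetical, chosen only for illustration). With gradient accumulation, each optimizer iteration consumes n_gradient_accumulation micro-batches of n_batch samples, so the old accounting under-counted samples and tokens by a factor of n_gradient_accumulation:

// accounting_sketch.cpp -- illustrative only; values are made up.
#include <cstdint>
#include <cstdio>

int main() {
    int new_iters    = 10;   // optimizer iterations since the last checkpoint
    int n_grad_accum = 4;    // corresponds to opt->params.n_gradient_accumulation
    int n_batch      = 8;    // samples per micro-batch
    int n_ctx        = 512;  // tokens per sample

    // Old (buggy) accounting: counts one micro-batch per iteration.
    int64_t old_samples = (int64_t) new_iters * n_batch;                 // 80

    // Fixed accounting: one iteration processes n_grad_accum micro-batches.
    int64_t new_samples = (int64_t) new_iters * n_grad_accum * n_batch;  // 320
    int64_t new_tokens  = new_samples * n_ctx;                           // 163840

    printf("old samples: %lld, fixed samples: %lld, fixed tokens: %lld\n",
           (long long) old_samples, (long long) new_samples, (long long) new_tokens);
    return 0;
}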