From d07b6aac7790c1cfc0292a91e7c42588ca29b3e0 Mon Sep 17 00:00:00 2001 From: xaedes Date: Tue, 5 Sep 2023 02:18:17 +0200 Subject: [PATCH] fix tracking of train_samples and train_tokens --- examples/finetune/finetune.cpp | 8 ++++---- .../train-text-from-scratch/train-text-from-scratch.cpp | 8 ++++---- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 6f133ac5f..d205367b3 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -2339,8 +2339,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) { if (save_now) { int new_iters = opt->iter - data->last_save_iter; data->lora->train_its += new_iters; - data->lora->train_samples += new_iters * n_batch; - data->lora->train_tokens += new_iters * n_batch * n_ctx; + data->lora->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; + data->lora->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; if (strlen(params->fn_checkpoint_out) > 0) { save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, opt->iter, params->fn_latest); @@ -2779,8 +2779,8 @@ int main(int argc, char ** argv) { int new_iters = opt->iter - opt_cb_data.last_save_iter; if (new_iters > 0) { lora.train_its += new_iters; - lora.train_samples += new_iters * n_batch; - lora.train_tokens += new_iters * n_batch * n_tokens; + lora.train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; + lora.train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; if (strlen(params.fn_checkpoint_out) > 0) { save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, opt, params.pattern_fn_it, opt->iter, params.fn_latest); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 21dacfeba..549302a81 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1800,8 +1800,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) { if (save_now) { int new_iters = opt->iter - data->last_save_iter; data->model->train_its += new_iters; - data->model->train_samples += new_iters * n_batch; - data->model->train_tokens += new_iters * n_batch * n_ctx; + data->model->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; + data->model->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; if (strlen(params->fn_checkpoint_out) > 0) { save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest); @@ -2122,8 +2122,8 @@ int main(int argc, char ** argv) { int new_iters = opt->iter - opt_cb_data.last_save_iter; model.train_its += new_iters; - model.train_samples += new_iters * n_batch; - model.train_tokens += new_iters * n_batch * n_tokens; + model.train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; + model.train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; if (params.n_examples > 0) { save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt, params.pattern_fn_it, opt->iter, params.fn_latest);