From de6170d8184137ec13496af5ac1aff77e1d083b2 Mon Sep 17 00:00:00 2001 From: xaedes Date: Wed, 6 Sep 2023 21:35:21 +0200 Subject: [PATCH] fix gradient accumulation bug where the same batch was used for each microstep --- examples/finetune/finetune.cpp | 4 ++-- examples/train-text-from-scratch/train-text-from-scratch.cpp | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 632392970..ed6bd8793 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -2049,7 +2049,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { invalid_param = true; break; } - params->n_gradient_accumulation = std::stoi(argv[i]); + params->n_gradient_accumulation = std::max(1, std::stoi(argv[i])); } else if (arg == "--norm-rms-eps") { if (++i >= argc) { invalid_param = true; @@ -2449,7 +2449,7 @@ void opt_callback(void * vdata, int accum_step, float * sched) { data->samples_size, data->tokens_data, data->tokens_size, - opt->iter, + opt->iter * params->n_gradient_accumulation, data->tokens_input, data->target_probs); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 549302a81..0a486f553 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1599,7 +1599,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { invalid_param = true; break; } - params->n_gradient_accumulation = std::stoi(argv[i]); + params->n_gradient_accumulation = std::max(1, std::stoi(argv[i])); } else if (arg == "-n" || arg == "--examples") { if (++i >= argc) { invalid_param = true; @@ -1846,7 +1846,7 @@ void opt_callback(void * vdata, int accum_step, float * sched) { data->samples_size, data->tokens_data, data->tokens_size, - opt->iter, + opt->iter * params->n_gradient_accumulation, data->tokens_input, data->target_logits, data->target_probs);