fix gradient accumulation bug where the same batch was used for each microstep

xaedes 2023-09-06 21:35:21 +02:00
parent 0393116628
commit de6170d818
2 changed files with 4 additions and 4 deletions
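
The bug: ggml's optimizer invokes opt_callback once per microstep (the accum_step parameter in the callback signature below counts these), but opt->iter only advances once per full optimizer iteration. Deriving the example index from opt->iter alone therefore handed every microstep of an iteration the identical batch, so gradient accumulation averaged n_gradient_accumulation copies of one gradient instead of sampling more data. The following is a minimal standalone sketch of the index arithmetic, not code from the repository; in particular, the `+ accum_step` offset is an assumption about how the scaled base combines with the callback's accum_step argument, since only the scaling appears in this diff.

#include <algorithm>
#include <cstdio>

// Hypothetical helpers for illustration only (not llama.cpp code).
// `iter` stands in for opt->iter, `n_accum` for params->n_gradient_accumulation.

// Old behaviour: the index ignores the microstep, so all microsteps
// of one iteration select the same batch.
static int example_index_old(int iter, int /*accum_step*/) {
    return iter;
}

// Fixed behaviour: scaling by n_accum gives each iteration its own block of
// n_accum indices; the accum_step offset (assumed, see above) picks one
// distinct batch per microstep within that block.
static int example_index_new(int iter, int accum_step, int n_accum) {
    return iter * n_accum + accum_step;
}

int main() {
    const int n_accum = std::max(1, 4); // same clamp as in train_params_parse
    for (int iter = 0; iter < 2; ++iter) {
        for (int accum_step = 0; accum_step < n_accum; ++accum_step) {
            std::printf("iter=%d step=%d  old=%d  new=%d\n",
                        iter, accum_step,
                        example_index_old(iter, accum_step),
                        example_index_new(iter, accum_step, n_accum));
        }
    }
    // Output: old= stays stuck at the iteration number for all four
    // microsteps, while new= walks 0..7 with no index reused.
    return 0;
}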


@@ -2049,7 +2049,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 invalid_param = true;
                 break;
             }
-            params->n_gradient_accumulation = std::stoi(argv[i]);
+            params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
         } else if (arg == "--norm-rms-eps") {
             if (++i >= argc) {
                 invalid_param = true;
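
The other half of the fix is the std::max(1, ...) clamp: n_gradient_accumulation now multiplies the example index (and determines how many microsteps run per iteration), so a 0 or negative value from the command line would zero out the index arithmetic. A small sketch of the parsing pattern, with a made-up flag name since the actual flag lies outside the hunk's context:

#include <algorithm>
#include <cstdio>
#include <string>

int main(int argc, char ** argv) {
    int n_gradient_accumulation = 1;
    for (int i = 1; i < argc; ++i) {
        std::string arg = argv[i];
        if (arg == "--grad-accum") { // hypothetical flag name
            if (++i >= argc) break;  // missing value: stop parsing
            // std::stoi throws on non-numeric input; the clamp rejects
            // 0 and negatives, which would mean "run no microsteps".
            n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
        }
    }
    std::printf("n_gradient_accumulation = %d\n", n_gradient_accumulation);
    return 0;
}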
@@ -2449,7 +2449,7 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
             data->samples_size,
             data->tokens_data,
             data->tokens_size,
-            opt->iter,
+            opt->iter * params->n_gradient_accumulation,
             data->tokens_input,
             data->target_probs);
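
Concretely, with n_gradient_accumulation = 4: before this hunk, iteration 3 passed example index 3 to get_example_targets_batch on all four microsteps; after it, the base index is 3 * 4 = 12, iteration 4 starts at 16, and (assuming the accum_step offset noted above) the four microsteps of iteration 3 cover indices 12 through 15 with no batch reused.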


@@ -1599,7 +1599,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 invalid_param = true;
                 break;
             }
-            params->n_gradient_accumulation = std::stoi(argv[i]);
+            params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
         } else if (arg == "-n" || arg == "--examples") {
            if (++i >= argc) {
                 invalid_param = true;
@@ -1846,7 +1846,7 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
             data->samples_size,
             data->tokens_data,
             data->tokens_size,
-            opt->iter,
+            opt->iter * params->n_gradient_accumulation,
             data->tokens_input,
             data->target_logits,
             data->target_probs);
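
The second file receives the identical two-line fix: as the matching train_params_parse and opt_callback hunks show, both changed files carry duplicated copies of this argument parsing and batch-selection code, so the bug had to be patched in each.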