fix gradient accumulation bug where the same batch was used for each microstep

xaedes 2023-09-06 22:45:36 +02:00
parent de6170d818
commit 0c2c9c7545
2 changed files with 2 additions and 2 deletions


@@ -2449,7 +2449,7 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
 data->samples_size,
 data->tokens_data,
 data->tokens_size,
-opt->iter * params->n_gradient_accumulation,
+opt->iter*params->n_gradient_accumulation + accum_step,
 data->tokens_input,
 data->target_probs);


@@ -1846,7 +1846,7 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
 data->samples_size,
 data->tokens_data,
 data->tokens_size,
-opt->iter * params->n_gradient_accumulation,
+opt->iter*params->n_gradient_accumulation + accum_step,
 data->tokens_input,
 data->target_logits,
 data->target_probs);
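
Both hunks apply the same one-line fix: the batch index passed to the data-loading call previously advanced only once per optimizer iteration (opt->iter * params->n_gradient_accumulation), so all n_gradient_accumulation microsteps within one iteration fetched the identical batch; adding accum_step gives each microstep its own offset. A minimal standalone C++ sketch of the index arithmetic follows; the loop structure and constants are illustrative assumptions, not code from this commit, and only the two expressions mirror the removed and added diff lines:

// Sketch (not from this commit) of the index arithmetic the fix changes.
#include <cstdio>

int main() {
    const int n_gradient_accumulation = 4; // assumed value of params->n_gradient_accumulation
    for (int iter = 0; iter < 2; ++iter) {
        for (int accum_step = 0; accum_step < n_gradient_accumulation; ++accum_step) {
            // before the fix: constant within an iteration, so every
            // microstep loads the same batch
            int before = iter * n_gradient_accumulation;
            // after the fix: each microstep gets a distinct offset
            int after  = iter * n_gradient_accumulation + accum_step;
            printf("iter=%d accum_step=%d before=%d after=%d\n",
                   iter, accum_step, before, after);
        }
    }
    return 0;
}

Running this prints before stuck at 0 (then 4) across all four microsteps of each iteration, while after advances 0,1,2,3 (then 4,5,6,7): one distinct batch per microstep, which is the behaviour the commit introduces.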