fix gradient accumulation bug where the same batch was used for each microstep
This commit is contained in:
parent
0393116628
commit
de6170d818
2 changed files with 4 additions and 4 deletions
|
@ -2049,7 +2049,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params->n_gradient_accumulation = std::stoi(argv[i]);
|
params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
|
||||||
} else if (arg == "--norm-rms-eps") {
|
} else if (arg == "--norm-rms-eps") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -2449,7 +2449,7 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
|
||||||
data->samples_size,
|
data->samples_size,
|
||||||
data->tokens_data,
|
data->tokens_data,
|
||||||
data->tokens_size,
|
data->tokens_size,
|
||||||
opt->iter,
|
opt->iter * params->n_gradient_accumulation,
|
||||||
data->tokens_input,
|
data->tokens_input,
|
||||||
data->target_probs);
|
data->target_probs);
|
||||||
|
|
||||||
|
|
|
@ -1599,7 +1599,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params->n_gradient_accumulation = std::stoi(argv[i]);
|
params->n_gradient_accumulation = std::max(1, std::stoi(argv[i]));
|
||||||
} else if (arg == "-n" || arg == "--examples") {
|
} else if (arg == "-n" || arg == "--examples") {
|
||||||
if (++i >= argc) {
|
if (++i >= argc) {
|
||||||
invalid_param = true;
|
invalid_param = true;
|
||||||
|
@ -1846,7 +1846,7 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
|
||||||
data->samples_size,
|
data->samples_size,
|
||||||
data->tokens_data,
|
data->tokens_data,
|
||||||
data->tokens_size,
|
data->tokens_size,
|
||||||
opt->iter,
|
opt->iter * params->n_gradient_accumulation,
|
||||||
data->tokens_input,
|
data->tokens_input,
|
||||||
data->target_logits,
|
data->target_logits,
|
||||||
data->target_probs);
|
data->target_probs);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue