increase train_samples by used_samples instead of number of batches

One batch can contain more than one sample when the option "fill_with_next_samples" is used.
This commit is contained in:
xaedes 2023-09-16 20:23:05 +02:00
parent 48d3509190
commit 571dc94da9
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
3 changed files with 3 additions and 5 deletions

View file

@ -1367,9 +1367,8 @@ void train_opt_callback(void * vdata, int accum_step, float * sched) {
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
if (save_now) {
int new_iters = opt->iter - data->last_save_iter;
train->train_its += new_iters;
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
train->train_its += new_iters;
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;
if (data->save_cb) {
data->save_cb(data->save_data, train);
@ -1431,6 +1430,7 @@ void train_opt_callback(void * vdata, int accum_step, float * sched) {
params->separate_with_bos,
params->fill_with_next_samples);
train->train_samples += used_samples;
train->shuffle_next_sample += used_samples;
if (train->shuffle_next_sample >= train->shuffle_sample_count) {

View file

@ -1938,7 +1938,6 @@ int main(int argc, char ** argv) {
int new_iters = opt->iter - opt_cb_data.last_save_iter;
if (new_iters > 0) {
train->train_its += new_iters;
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
save_train_files(&save_data, train);

View file

@ -1183,7 +1183,6 @@ int main(int argc, char ** argv) {
int new_iters = opt->iter - opt_cb_data.last_save_iter;
if (new_iters > 0) {
train->train_its += new_iters;
train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;
save_train_files(&save_data, train);