fix tracking of train_samples and train_tokens

xaedes 2023-09-05 02:18:17 +02:00
parent c1c3b0e0c2
commit d07b6aac77
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1
2 changed files with 8 additions and 8 deletions
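
Why the fix: with gradient accumulation enabled, each optimizer iteration (opt->iter) runs opt->params.n_gradient_accumulation forward/backward passes over n_batch samples each, so the old bookkeeping undercounted train_samples and train_tokens by that factor. Below is a minimal standalone sketch of the corrected arithmetic, assuming the counter semantics implied by the diff; train_state and account_iters are hypothetical names used for illustration only, not llama.cpp API:

    #include <cstdint>
    #include <cstdio>

    // Hypothetical stand-in for the train_its/train_samples/train_tokens
    // counters that the diff updates.
    struct train_state {
        int64_t train_its     = 0; // optimizer iterations completed
        int64_t train_samples = 0; // training samples consumed
        int64_t train_tokens  = 0; // training tokens consumed
    };

    // One optimizer iteration performs n_grad_accum accumulation steps of
    // n_batch samples each, and every sample is n_ctx tokens long.
    static void account_iters(train_state & st, int new_iters,
                              int n_grad_accum, int n_batch, int n_ctx) {
        st.train_its     += new_iters;
        st.train_samples += (int64_t) new_iters * n_grad_accum * n_batch;
        st.train_tokens  += (int64_t) new_iters * n_grad_accum * n_batch * n_ctx;
    }

    int main() {
        train_state st;
        // 10 new iterations, 4 accumulation steps, batch size 8, context 512:
        // the pre-fix code would count 10*8 = 80 samples; the fix counts 320.
        account_iters(st, 10, 4, 8, 512);
        std::printf("its=%lld samples=%lld tokens=%lld\n",
                    (long long) st.train_its,
                    (long long) st.train_samples,
                    (long long) st.train_tokens);
        return 0;
    }

The same multiply-by-accumulation-steps correction is applied at all four call sites below: the checkpoint callback and the final save in each of the two changed files.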

examples/finetune/finetune.cpp

@@ -2339,8 +2339,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
     if (save_now) {
         int new_iters = opt->iter - data->last_save_iter;
         data->lora->train_its     += new_iters;
-        data->lora->train_samples += new_iters * n_batch;
-        data->lora->train_tokens  += new_iters * n_batch * n_ctx;
+        data->lora->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
+        data->lora->train_tokens  += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;

         if (strlen(params->fn_checkpoint_out) > 0) {
             save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
@@ -2779,8 +2779,8 @@ int main(int argc, char ** argv) {
     int new_iters = opt->iter - opt_cb_data.last_save_iter;
     if (new_iters > 0) {
         lora.train_its     += new_iters;
-        lora.train_samples += new_iters * n_batch;
-        lora.train_tokens  += new_iters * n_batch * n_tokens;
+        lora.train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
+        lora.train_tokens  += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;

         if (strlen(params.fn_checkpoint_out) > 0) {
             save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, opt, params.pattern_fn_it, opt->iter, params.fn_latest);

examples/train-text-from-scratch/train-text-from-scratch.cpp

@@ -1800,8 +1800,8 @@ void opt_callback(void * vdata, int accum_step, float * sched) {
     if (save_now) {
         int new_iters = opt->iter - data->last_save_iter;
         data->model->train_its     += new_iters;
-        data->model->train_samples += new_iters * n_batch;
-        data->model->train_tokens  += new_iters * n_batch * n_ctx;
+        data->model->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
+        data->model->train_tokens  += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx;

         if (strlen(params->fn_checkpoint_out) > 0) {
             save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest);
@@ -2122,8 +2122,8 @@ int main(int argc, char ** argv) {
     int new_iters = opt->iter - opt_cb_data.last_save_iter;
     model.train_its     += new_iters;
-    model.train_samples += new_iters * n_batch;
-    model.train_tokens  += new_iters * n_batch * n_tokens;
+    model.train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch;
+    model.train_tokens  += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens;

     if (params.n_examples > 0) {
         save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt, params.pattern_fn_it, opt->iter, params.fn_latest);