update checkpoint train stats before saving via "--save-every"

xaedes 2023-08-23 19:34:45 +02:00
parent 8b4106ae33
commit 77a3092c83


@@ -2676,9 +2676,15 @@ void opt_callback(void * vdata, float * sched) {
     struct train_params * params = data->params;
     struct ggml_opt_context * opt = data->opt;
     int n_batch = params->n_batch;
+    int n_ctx = params->n_ctx;

     const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
     if (save_now) {
+        int new_iters = opt->iter - data->last_save_iter;
+        data->lora->train_its += new_iters;
+        data->lora->train_samples += new_iters * n_batch;
+        data->lora->train_tokens += new_iters * n_batch * n_ctx;
         if (strlen(params->fn_checkpoint_out) > 0) {
             save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, opt->iter, params->fn_latest);
             save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, -1, params->fn_latest);
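Note: the hunk above folds every iteration run since the previous checkpoint into the stored training statistics before the checkpoint is written, using new_iters = opt->iter - data->last_save_iter. A minimal standalone sketch of that accounting, with simplified stand-in structs (field names mirror the diff; the types themselves are illustrative, not the repo's):

    #include <stdio.h>

    struct opt_state   { int iter; };                   /* stand-in for ggml_opt_context */
    struct train_stats { long its, samples, tokens; };  /* stand-in for the lora train stats */

    /* Fold everything trained since the last checkpoint into the running totals. */
    static void update_stats_before_save(struct train_stats *st, const struct opt_state *opt,
                                         int last_save_iter, int n_batch, int n_ctx) {
        int new_iters = opt->iter - last_save_iter;      /* iterations since the last save */
        st->its     += new_iters;
        st->samples += new_iters * n_batch;              /* one batch of samples per iteration */
        st->tokens  += new_iters * n_batch * n_ctx;      /* n_ctx tokens per sample */
    }

    int main(void) {
        struct opt_state   opt = { 150 };
        struct train_stats st  = { 100, 800, 409600 };   /* totals from earlier checkpoints */
        int last_save_iter = 100, n_batch = 8, n_ctx = 512;

        update_stats_before_save(&st, &opt, last_save_iter, n_batch, n_ctx);
        last_save_iter = opt.iter;                       /* remember the save point */

        printf("its=%ld samples=%ld tokens=%ld\n", st.its, st.samples, st.tokens);
        return 0;
    }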
@@ -3004,10 +3010,6 @@ int main(int argc, char ** argv) {
         size_t used_mem_after_opt = ggml_used_mem(ctx0);

-        int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter;
-        lora.train_its = opt->iter;
-        lora.train_samples += n_batch * n_iter;
-        lora.train_tokens += n_batch * n_tokens * n_iter;

        if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
            printf("Example %d, opt iter %d\n", ex, opt->iter);
@@ -3050,6 +3052,11 @@ int main(int argc, char ** argv) {
    double dd = (double) d * 1e-3;
    printf("%s: total training time=%f seconds\n", __func__, dd);

+    int new_iters = opt->iter - opt_cb_data.last_save_iter;
+    lora.train_its += new_iters;
+    lora.train_samples += new_iters * n_batch;
+    lora.train_tokens += new_iters * n_batch * n_tokens;
    if (params.n_examples > 0) {
        save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, opt->iter, params.fn_latest);
        save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, -1, params.fn_latest);
@@ -3060,6 +3067,8 @@ int main(int argc, char ** argv) {
        save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest);
    }

+    opt_cb_data.last_save_iter = opt->iter;

    {
        int n_gen = params.n_predict;
        int sample_ctx = n_tokens - n_tokens/8;
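Taken together, the last two hunks run the same delta update once more at the end of training and then advance opt_cb_data.last_save_iter, so iterations already counted by a "--save-every" checkpoint are not counted again in the final checkpoint. A small self-contained trace of that behaviour (all values are made up; n_ctx stands in for the n_tokens used in main):

    #include <stdio.h>

    int main(void) {
        int n_batch = 4, n_ctx = 128;
        long train_its = 0, train_samples = 0, train_tokens = 0;
        int last_save_iter = 0;
        int iters_at_callback_save = 30, iters_at_end = 50;

        /* "--save-every" checkpoint: fold in iterations 0..30 */
        int d = iters_at_callback_save - last_save_iter;
        train_its += d; train_samples += d * n_batch; train_tokens += (long) d * n_batch * n_ctx;
        last_save_iter = iters_at_callback_save;

        /* final checkpoint: only iterations 30..50 are added, nothing is double counted */
        d = iters_at_end - last_save_iter;
        train_its += d; train_samples += d * n_batch; train_tokens += (long) d * n_batch * n_ctx;
        last_save_iter = iters_at_end;

        /* prints its=50 samples=200 tokens=25600: exactly 50 iterations counted once */
        printf("its=%ld samples=%ld tokens=%ld\n", train_its, train_samples, train_tokens);
        return 0;
    }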