update checkpoint train stats before saving via "--save-every"
This commit is contained in:
parent
8b4106ae33
commit
77a3092c83
1 changed files with 15 additions and 6 deletions
|
@ -2676,17 +2676,23 @@ void opt_callback(void * vdata, float * sched) {
|
|||
struct train_params * params = data->params;
|
||||
struct ggml_opt_context * opt = data->opt;
|
||||
int n_batch = params->n_batch;
|
||||
int n_ctx = params->n_ctx;
|
||||
|
||||
const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
|
||||
if (save_now) {
|
||||
int new_iters = opt->iter - data->last_save_iter;
|
||||
data->lora->train_its += new_iters;
|
||||
data->lora->train_samples += new_iters * n_batch;
|
||||
data->lora->train_tokens += new_iters * n_batch * n_ctx;
|
||||
|
||||
if (strlen(params->fn_checkpoint_out) > 0) {
|
||||
save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, opt->iter, params->fn_latest);
|
||||
save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, -1, params->fn_latest);
|
||||
}
|
||||
}
|
||||
if (strlen(params->fn_lora_out) > 0) {
|
||||
save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest);
|
||||
save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, -1, params->fn_latest);
|
||||
}
|
||||
}
|
||||
data->last_save_iter = opt->iter;
|
||||
}
|
||||
|
||||
|
@ -3004,10 +3010,6 @@ int main(int argc, char ** argv) {
|
|||
|
||||
size_t used_mem_after_opt = ggml_used_mem(ctx0);
|
||||
|
||||
int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter;
|
||||
lora.train_its = opt->iter;
|
||||
lora.train_samples += n_batch * n_iter;
|
||||
lora.train_tokens += n_batch * n_tokens * n_iter;
|
||||
|
||||
if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
|
||||
printf("Example %d, opt iter %d\n", ex, opt->iter);
|
||||
|
@ -3050,6 +3052,11 @@ int main(int argc, char ** argv) {
|
|||
double dd = (double) d * 1e-3;
|
||||
printf("%s: total training time=%f seconds\n", __func__, dd);
|
||||
|
||||
int new_iters = opt->iter - opt_cb_data.last_save_iter;
|
||||
lora.train_its += new_iters;
|
||||
lora.train_samples += new_iters * n_batch;
|
||||
lora.train_tokens += new_iters * n_batch * n_tokens;
|
||||
|
||||
if (params.n_examples > 0) {
|
||||
save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, opt->iter, params.fn_latest);
|
||||
save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, -1, params.fn_latest);
|
||||
|
@ -3060,6 +3067,8 @@ int main(int argc, char ** argv) {
|
|||
save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest);
|
||||
}
|
||||
|
||||
opt_cb_data.last_save_iter = opt->iter;
|
||||
|
||||
{
|
||||
int n_gen = params.n_predict;
|
||||
int sample_ctx = n_tokens - n_tokens/8;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue