update checkpoint train stats before saving via "--save-every"
parent 8b4106ae33
commit 77a3092c83

1 changed file with 15 additions and 6 deletions
@@ -2676,9 +2676,15 @@ void opt_callback(void * vdata, float * sched) {
     struct train_params * params = data->params;
     struct ggml_opt_context * opt = data->opt;
     int n_batch = params->n_batch;
+    int n_ctx = params->n_ctx;
 
     const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every);
     if (save_now) {
+        int new_iters = opt->iter - data->last_save_iter;
+        data->lora->train_its += new_iters;
+        data->lora->train_samples += new_iters * n_batch;
+        data->lora->train_tokens += new_iters * n_batch * n_ctx;
+
         if (strlen(params->fn_checkpoint_out) > 0) {
             save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, opt->iter, params->fn_latest);
             save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, -1, params->fn_latest);
@@ -3004,10 +3010,6 @@ int main(int argc, char ** argv) {
 
         size_t used_mem_after_opt = ggml_used_mem(ctx0);
 
-        int n_iter = params.use_adam ? params.adam_n_iter : params.lbfgs_n_iter;
-        lora.train_its = opt->iter;
-        lora.train_samples += n_batch * n_iter;
-        lora.train_tokens += n_batch * n_tokens * n_iter;
 
         if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
             printf("Example %d, opt iter %d\n", ex, opt->iter);
@@ -3050,6 +3052,11 @@ int main(int argc, char ** argv) {
     double dd = (double) d * 1e-3;
     printf("%s: total training time=%f seconds\n", __func__, dd);
 
+    int new_iters = opt->iter - opt_cb_data.last_save_iter;
+    lora.train_its += new_iters;
+    lora.train_samples += new_iters * n_batch;
+    lora.train_tokens += new_iters * n_batch * n_tokens;
+
     if (params.n_examples > 0) {
         save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, opt->iter, params.fn_latest);
         save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, -1, params.fn_latest);
@@ -3060,6 +3067,8 @@ int main(int argc, char ** argv) {
         save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest);
     }
 
+    opt_cb_data.last_save_iter = opt->iter;
+
     {
         int n_gen = params.n_predict;
         int sample_ctx = n_tokens - n_tokens/8;
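The net effect of the patch: the counters that save_checkpoint writes (train_its, train_samples, train_tokens) are now advanced from the actual iteration delta (opt->iter minus the last save iteration) at the moment a checkpoint is written, instead of being bumped once per example with the configured adam_n_iter/lbfgs_n_iter, so checkpoints produced via "--save-every" carry up-to-date stats. Below is a small standalone sketch, not part of the patch, that mirrors this accounting; TrainStats and maybe_update_stats are illustrative names only, not identifiers from the codebase.

#include <stdio.h>

// Illustrative stand-ins for the checkpoint stats this commit keeps in sync
// (hypothetical names; the real fields live on the lora struct in the patch).
struct TrainStats {
    long train_its;      // total optimizer iterations
    long train_samples;  // total training examples seen
    long train_tokens;   // total tokens seen
};

// Mirrors the arithmetic added in opt_callback: when "--save-every" iterations
// have elapsed since the last checkpoint, fold the elapsed iterations into the
// stats before saving, so the written checkpoint reports them correctly.
static void maybe_update_stats(struct TrainStats * stats,
                               int iter, int last_save_iter,
                               int save_every, int n_batch, int n_ctx) {
    int save_now = (save_every > 0) && (iter - last_save_iter >= save_every);
    if (!save_now) {
        return;
    }
    int new_iters = iter - last_save_iter;
    stats->train_its     += new_iters;
    stats->train_samples += new_iters * n_batch;
    stats->train_tokens  += (long) new_iters * n_batch * n_ctx;
}

int main(void) {
    struct TrainStats stats = {0, 0, 0};
    // e.g. --save-every 10 with batch size 8 and context length 128
    maybe_update_stats(&stats, /*iter=*/10, /*last_save_iter=*/0,
                       /*save_every=*/10, /*n_batch=*/8, /*n_ctx=*/128);
    printf("its=%ld samples=%ld tokens=%ld\n",
           stats.train_its, stats.train_samples, stats.train_tokens);
    // prints: its=10 samples=80 tokens=10240
    return 0;
}

The same delta-based update runs once more at the end of training (third hunk), and last_save_iter is then reset to opt->iter (fourth hunk), which keeps iterations already folded in by the callback from being counted twice.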