diff --git a/common/train.cpp b/common/train.cpp index c2b3f036b..81039e5eb 100644 --- a/common/train.cpp +++ b/common/train.cpp @@ -1001,3 +1001,8 @@ size_t tokenize_file( return out_tokens.size(); } + +std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration) { + std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); + return replace_str(filename, pattern_it, sit.c_str()); +} diff --git a/common/train.h b/common/train.h index 54edd0f4a..59004a87c 100644 --- a/common/train.h +++ b/common/train.h @@ -133,3 +133,6 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context * bool load_train_state_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct train_state * train); void save_train_state_gguf(struct gguf_context * fctx, struct train_state * train); +std::string get_train_filename(const char * filename, const char * pattern_it, const char * latest, int64_t iteration); + +typedef void (*save_train_files_callback)(void * data, struct train_state * train); diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 58e96f186..5c787e94e 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -1016,17 +1016,15 @@ static bool load_checkpoint_lora_file(const char * filename, struct my_llama_mod return true; } -static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train, const char * pattern_it, int iteration, const char * latest) { - std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); - std::string fn = replace_str(filename, pattern_it, sit.c_str()); - printf("%s: saving to %s\n", __func__, fn.c_str()); +static void save_checkpoint_lora_file(const char * filename, struct my_llama_model * model, struct my_llama_lora * lora, struct train_state * train) { + printf("%s: saving to %s\n", __func__, filename); struct gguf_context * fctx = gguf_init_empty(); save_checkpoint_lora_gguf(fctx, model, lora, train); // write file const bool only_meta = false; - gguf_write_to_file(fctx, fn.c_str(), only_meta); + gguf_write_to_file(fctx, filename, only_meta); gguf_free(fctx); } @@ -1139,11 +1137,9 @@ static void write_tensor(struct llama_file * file, struct ggml_tensor * tensor, file->write_raw(tensor->data, ggml_nbytes(tensor)); } -static void save_as_llama_lora(struct my_llama_lora * lora, const char * filename, const char * pattern_it, int iteration, const char * latest) { - std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); - std::string fn = replace_str(filename, pattern_it, sit.c_str()); - printf("%s: saving to %s\n", __func__, fn.c_str()); - struct llama_file file(fn.c_str(), "wb"); +static void save_as_llama_lora(const char * filename, struct my_llama_lora * lora) { + printf("%s: saving to %s\n", __func__, filename); + struct llama_file file(filename, "wb"); if (file.fp == NULL) { return; } @@ -1823,25 +1819,49 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par return true; } -struct opt_callback_data { - struct train_params * params; - struct train_state * train; +struct save_train_files_data { + const char * fn_checkpoint_out; + const char * fn_lora_out; + const char * pattern_fn_it; + const char * fn_latest; struct my_llama_model * model; struct my_llama_lora * lora; - struct llama_context * lctx; - int last_save_iter; - llama_token * tokens_data; - size_t tokens_size; - size_t * samples_begin; - size_t * samples_size; - size_t * shuffled_samples_begin; - size_t * shuffled_samples_size; - size_t samples_count; - struct ggml_tensor * tokens_input; - struct ggml_tensor * target_probs; - int first_iter; - int64_t last_time; - double millis_per_iter; +}; + +static void save_train_files(void * vdata, struct train_state * train) { + struct save_train_files_data * data = (struct save_train_files_data *) vdata; + + int64_t iter = train->opt->iter; + + if (strlen(data->fn_checkpoint_out) > 0) { + save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->model, data->lora, train); + save_checkpoint_lora_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->model, data->lora, train); + } + if (strlen(data->fn_lora_out) > 0) { + save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->lora); + save_as_llama_lora(get_train_filename(data->fn_lora_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->lora); + } +} + +struct opt_callback_data { + struct train_params * params; + struct train_state * train; + save_train_files_callback save_cb; + void * save_data; + struct llama_context * lctx; + int last_save_iter; + llama_token * tokens_data; + size_t tokens_size; + size_t * samples_begin; + size_t * samples_size; + size_t * shuffled_samples_begin; + size_t * shuffled_samples_size; + size_t samples_count; + struct ggml_tensor * tokens_input; + struct ggml_tensor * target_probs; + int first_iter; + int64_t last_time; + double millis_per_iter; }; static void opt_callback(void * vdata, int accum_step, float * sched) { @@ -1881,14 +1901,10 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; - if (strlen(params->fn_checkpoint_out) > 0) { - save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, train, params->pattern_fn_it, opt->iter, params->fn_latest); - save_checkpoint_lora_file(params->fn_checkpoint_out, data->model, data->lora, train, params->pattern_fn_it, -1, params->fn_latest); - } - if (strlen(params->fn_lora_out) > 0) { - save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter, params->fn_latest); - save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, -1, params->fn_latest); + if (data->save_cb) { + data->save_cb(data->save_data, train); } + data->last_save_iter = opt->iter; } @@ -2140,10 +2156,17 @@ int main(int argc, char ** argv) { opt->iter = train->train_its; if (params.only_write_lora) { - if (strlen(params.fn_lora_out) > 0) { - save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter, params.fn_latest); - save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest); - } + save_train_files_data save_data; + save_data.fn_checkpoint_out = ""; + save_data.fn_lora_out = params.fn_lora_out; + save_data.pattern_fn_it = params.pattern_fn_it; + save_data.fn_latest = params.fn_latest; + save_data.model = &model; + save_data.lora = &lora; + + save_train_files(&save_data, train); + + free_train_state(train); ggml_free(lora.ctx); llama_free(lctx); llama_free_model(lmodel); @@ -2323,12 +2346,20 @@ int main(int argc, char ** argv) { printf("%s: begin training\n", __func__); + save_train_files_data save_data; + save_data.fn_checkpoint_out = params.fn_checkpoint_out; + save_data.fn_lora_out = params.fn_lora_out; + save_data.pattern_fn_it = params.pattern_fn_it; + save_data.fn_latest = params.fn_latest; + save_data.model = &model; + save_data.lora = &lora; + struct opt_callback_data opt_cb_data; - opt_cb_data.params = ¶ms; - opt_cb_data.train = train; - opt_cb_data.model = &model; - opt_cb_data.lora = &lora; - opt_cb_data.lctx = lctx; + opt_cb_data.params = ¶ms; + opt_cb_data.train = train; + opt_cb_data.save_cb = &save_train_files; + opt_cb_data.save_data = &save_data; + opt_cb_data.lctx = lctx; opt_cb_data.last_save_iter = opt->iter; opt_cb_data.tokens_data = train_tokens.data(); opt_cb_data.tokens_size = train_tokens.size(); @@ -2374,14 +2405,7 @@ int main(int argc, char ** argv) { train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; - if (strlen(params.fn_checkpoint_out) > 0) { - save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, train, params.pattern_fn_it, opt->iter, params.fn_latest); - save_checkpoint_lora_file(params.fn_checkpoint_out, &model, &lora, train, params.pattern_fn_it, -1, params.fn_latest); - } - if (strlen(params.fn_lora_out) > 0) { - save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter, params.fn_latest); - save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, -1, params.fn_latest); - } + save_train_files(&save_data, train); opt_cb_data.last_save_iter = opt->iter; } diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index bead80843..7984dd724 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -640,17 +640,15 @@ static void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vo } } -static void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, const char * pattern_it, int iteration, const char * latest) { - std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); - std::string fn = replace_str(filename, pattern_it, sit.c_str()); - printf("%s: saving to %s\n", __func__, fn.c_str()); +static void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) { + printf("%s: saving to %s\n", __func__, filename); struct gguf_context * fctx = gguf_init_empty(); save_llama_model_gguf(fctx, fn_vocab_model, model); // write file const bool only_meta = false; - gguf_write_to_file(fctx, fn.c_str(), only_meta); + gguf_write_to_file(fctx, filename, only_meta); gguf_free(fctx); } @@ -681,17 +679,15 @@ static bool load_checkpoint_file(const char * filename, struct my_llama_model * return true; } -static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train, const char * pattern_it, int iteration, const char * latest) { - std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); - std::string fn = replace_str(filename, pattern_it, sit.c_str()); - printf("%s: saving to %s\n", __func__, fn.c_str()); +static void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct train_state * train) { + printf("%s: saving to %s\n", __func__, filename); struct gguf_context * fctx = gguf_init_empty(); save_checkpoint_gguf(fctx, fn_vocab_model, model, train); // write file const bool only_meta = false; - gguf_write_to_file(fctx, fn.c_str(), only_meta); + gguf_write_to_file(fctx, filename, only_meta); gguf_free(fctx); } @@ -1220,25 +1216,50 @@ static bool train_params_parse(int argc, char ** argv, struct train_params * par return true; } -struct opt_callback_data { - struct train_params * params; - struct train_state * train; +struct save_train_files_data { + const char * fn_checkpoint_out; + const char * fn_model_out; + const char * fn_vocab_model; + const char * pattern_fn_it; + const char * fn_latest; struct my_llama_model * model; - struct llama_context * lctx; - int last_save_iter; - llama_token * tokens_data; - size_t tokens_size; - size_t * samples_begin; - size_t * samples_size; - size_t * shuffled_samples_begin; - size_t * shuffled_samples_size; - size_t samples_count; - struct ggml_tensor * tokens_input; - struct ggml_tensor * target_logits; - struct ggml_tensor * target_probs; - int first_iter; - int64_t last_time; - double millis_per_iter; +}; + +static void save_train_files(void * vdata, struct train_state * train) { + struct save_train_files_data * data = (struct save_train_files_data *) vdata; + int64_t iter = train->opt->iter; + + if (strlen(data->fn_checkpoint_out) > 0) { + save_checkpoint_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->fn_vocab_model, data->model, train); + save_checkpoint_file(get_train_filename(data->fn_checkpoint_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->fn_vocab_model, data->model, train); + + } + if (strlen(data->fn_model_out) > 0) { + save_llama_model_file(get_train_filename(data->fn_model_out, data->pattern_fn_it, data->fn_latest, iter).c_str(), data->fn_vocab_model, data->model); + save_llama_model_file(get_train_filename(data->fn_model_out, data->pattern_fn_it, data->fn_latest, -1 ).c_str(), data->fn_vocab_model, data->model); + } +} + +struct opt_callback_data { + struct train_params * params; + struct train_state * train; + save_train_files_callback save_cb; + void * save_data; + struct llama_context * lctx; + int last_save_iter; + llama_token * tokens_data; + size_t tokens_size; + size_t * samples_begin; + size_t * samples_size; + size_t * shuffled_samples_begin; + size_t * shuffled_samples_size; + size_t samples_count; + struct ggml_tensor * tokens_input; + struct ggml_tensor * target_logits; + struct ggml_tensor * target_probs; + int first_iter; + int64_t last_time; + double millis_per_iter; }; static void opt_callback(void * vdata, int accum_step, float * sched) { @@ -1278,15 +1299,10 @@ static void opt_callback(void * vdata, int accum_step, float * sched) { train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_ctx; - if (strlen(params->fn_checkpoint_out) > 0) { - save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, train, params->pattern_fn_it, opt->iter, params->fn_latest); - save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, train, params->pattern_fn_it, -1, params->fn_latest); + if (data->save_cb) { + data->save_cb(data->save_data, train); + } - } - if (strlen(params->fn_model_out) > 0) { - save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, opt->iter, params->fn_latest); - save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, -1, params->fn_latest); - } data->last_save_iter = opt->iter; } @@ -1508,14 +1524,23 @@ int main(int argc, char ** argv) { train_samples_size.size()); printf("%s: begin training\n", __func__); + save_train_files_data save_data; + save_data.fn_checkpoint_out = params.fn_checkpoint_out; + save_data.fn_model_out = params.fn_model_out; + save_data.fn_vocab_model = params.fn_vocab_model; + save_data.pattern_fn_it = params.pattern_fn_it; + save_data.fn_latest = params.fn_latest; + save_data.model = &model; + struct opt_callback_data opt_cb_data; - opt_cb_data.params = ¶ms; - opt_cb_data.train = train; - opt_cb_data.model = &model; - opt_cb_data.lctx = lctx; - opt_cb_data.last_save_iter = opt->iter; - opt_cb_data.tokens_data = train_tokens.data(); - opt_cb_data.tokens_size = train_tokens.size(); + opt_cb_data.params = ¶ms; + opt_cb_data.train = train; + opt_cb_data.save_cb = &save_train_files; + opt_cb_data.save_data = &save_data; + opt_cb_data.lctx = lctx; + opt_cb_data.last_save_iter = opt->iter; + opt_cb_data.tokens_data = train_tokens.data(); + opt_cb_data.tokens_size = train_tokens.size(); opt_cb_data.samples_begin = train_samples_begin.data(); opt_cb_data.samples_size = train_samples_size.data(); opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data(); @@ -1620,22 +1645,15 @@ int main(int argc, char ** argv) { printf("%s: total training time=%f seconds\n", __func__, dd); int new_iters = opt->iter - opt_cb_data.last_save_iter; - train->train_its += new_iters; - train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; - train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; + if (new_iters > 0) { + train->train_its += new_iters; + train->train_samples += new_iters * opt->params.n_gradient_accumulation * n_batch; + train->train_tokens += new_iters * opt->params.n_gradient_accumulation * n_batch * n_tokens; - if (params.n_examples > 0) { - save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, train, params.pattern_fn_it, opt->iter, params.fn_latest); - save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, train, params.pattern_fn_it, -1, params.fn_latest); + save_train_files(&save_data, train); + opt_cb_data.last_save_iter = opt->iter; } - if (strlen(params.fn_model_out) > 0) { - save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model, params.pattern_fn_it, opt->iter, params.fn_latest); - save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model, params.pattern_fn_it, -1, params.fn_latest); - } - - opt_cb_data.last_save_iter = opt->iter; - if (alloc) { ggml_allocr_free(alloc); }