diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp index 67c2aef34..ad5631da1 100644 --- a/examples/finetune/finetune.cpp +++ b/examples/finetune/finetune.cpp @@ -1856,8 +1856,19 @@ void read_opt_context(struct llama_file * file, struct ggml_context * ctx, struc } } -void save_checkpoint(struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt, const char * filename) { - struct llama_file file(filename, "wb"); +std::string replace_str(const char * s, const char * needle, const char * replacement) { + std::string str = s; + size_t pos = str.find(needle); + if (pos != std::string::npos) { + str.replace(pos, strlen(needle), replacement); + } + return str; +} + +void save_checkpoint(struct my_llama_model * model, struct my_llama_lora * lora, struct ggml_opt_context * opt, const char * filename, const char * pattern_it, int iteration) { + std::string sit = std::to_string(iteration); + std::string fn = replace_str(filename, pattern_it, sit.c_str()); + struct llama_file file(fn.c_str(), "wb"); if (file.fp == NULL) { return; } @@ -2021,8 +2032,10 @@ bool load_checkpoint(struct my_llama_model * model, struct my_llama_lora * lora, return (file.fp != NULL); } -void save_as_llama_lora(struct my_llama_lora * lora, const char * filename) { - struct llama_file file(filename, "wb"); +void save_as_llama_lora(struct my_llama_lora * lora, const char * filename, const char * pattern_it, int iteration) { + std::string sit = std::to_string(iteration); + std::string fn = replace_str(filename, pattern_it, sit.c_str()); + struct llama_file file(fn.c_str(), "wb"); if (file.fp == NULL) { return; } @@ -2088,6 +2101,9 @@ struct train_params { const char * fn_checkpoint_in; const char * fn_checkpoint_out; const char * fn_lora_out; + const char * pattern_fn_it; + + int save_every; uint32_t seed; @@ -2154,8 +2170,11 @@ struct train_params get_default_train_params() { params.fn_model_base = ""; params.fn_train_data = "shakespeare.txt"; 
params.fn_checkpoint_in = "checkpoint.bin"; - params.fn_checkpoint_out = "checkpoint.bin"; - params.fn_lora_out = "ggml-lora-f32.bin"; + params.fn_checkpoint_out = "checkpoint-ITERATION.bin"; + params.fn_lora_out = "ggml-lora-ITERATION-f32.bin"; + params.pattern_fn_it = "ITERATION"; + + params.save_every = 10; params.seed = -1; @@ -2228,6 +2247,8 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in); fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out); fprintf(stderr, " --lora-out FNAME path to save llama lora (default '%s')\n", params->fn_lora_out); + fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it); + fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. 
(default '%d')\n", params->save_every); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n"); fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx); fprintf(stderr, " -t N, --threads N Number of threads (default %d)\n", params->n_threads); @@ -2325,6 +2346,18 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->fn_lora_out = argv[i]; + } else if (arg == "--pattern-fn-it") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->pattern_fn_it = argv[i]; + } else if (arg == "--save-every") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->save_every = std::stoi(argv[i]); } else if (arg == "-s" || arg == "--seed") { if (++i >= argc) { invalid_param = true; @@ -2614,6 +2647,9 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { struct opt_callback_data { struct train_params * params; struct ggml_opt_context * opt; + struct my_llama_model * model; + struct my_llama_lora * lora; + int last_save_iter; llama_token * tokens_data; size_t tokens_size; int * samples_data; @@ -2630,6 +2666,17 @@ void opt_callback(void * vdata, float * sched) { struct ggml_opt_context * opt = data->opt; int n_batch = params->n_batch; + const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every); + if (save_now) { + if (strlen(params->fn_checkpoint_out) > 0) { + save_checkpoint(data->model, data->lora, opt, params->fn_checkpoint_out, params->pattern_fn_it, opt->iter); + } + if (strlen(params->fn_lora_out) > 0) { + save_as_llama_lora(data->lora, params->fn_lora_out, params->pattern_fn_it, opt->iter); + } + data->last_save_iter = opt->iter; + } + *sched = (opt->iter < params->warmup) ? 
(float) opt->iter / (float) params->warmup : cosine_decay_restart( @@ -2854,6 +2901,9 @@ int main(int argc, char ** argv) { struct opt_callback_data opt_cb_data; opt_cb_data.params = &params; opt_cb_data.opt = opt; + opt_cb_data.model = &model; + opt_cb_data.lora = &lora; + opt_cb_data.last_save_iter = opt->iter; opt_cb_data.tokens_data = train_tokens.data(); opt_cb_data.tokens_size = train_tokens.size(); opt_cb_data.samples_data = train_samples.data(); @@ -2988,11 +3038,11 @@ int main(int argc, char ** argv) { printf("%s: total training time=%f seconds\n", __func__, dd); if (params.n_examples > 0) { - save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out); + save_checkpoint(&model, &lora, opt, params.fn_checkpoint_out, params.pattern_fn_it, opt->iter); } if (strlen(params.fn_lora_out) > 0) { - save_as_llama_lora(&lora, params.fn_lora_out); + save_as_llama_lora(&lora, params.fn_lora_out, params.pattern_fn_it, opt->iter); } {