diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 76cf501a5..2a9cf9c7c 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -793,6 +793,15 @@ void shuffle_ints(int * begin, int * end) { }); } +std::string replace_str(const char * s, const char * needle, const char * replacement) { + std::string str = s; + size_t pos = str.find(needle); + if (pos != std::string::npos) { + str.replace(pos, strlen(needle), replacement); + } + return str; +} + #define GGUF_GET_KEY(ctx, dst, func, type, req, key) \ { \ const std::string skey(key); \ @@ -1174,14 +1183,17 @@ void save_llama_model_gguf(struct gguf_context * fctx, const char * fn_vocab_mod } } -void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model) { +void save_llama_model_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, const char * pattern_it, int iteration, const char * latest) { + std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); + std::string fn = replace_str(filename, pattern_it, sit.c_str()); + printf("%s: saving to %s\n", __func__, fn.c_str()); struct gguf_context * fctx = gguf_init_empty(); save_llama_model_gguf(fctx, fn_vocab_model, model); // write file const bool only_meta = false; - gguf_write_to_file(fctx, filename, only_meta); + gguf_write_to_file(fctx, fn.c_str(), only_meta); gguf_free(fctx); } @@ -1234,14 +1246,17 @@ bool load_checkpoint_file(const char * filename, struct my_llama_model * model, return true; } -void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt) { +void save_checkpoint_file(const char * filename, const char * fn_vocab_model, struct my_llama_model * model, struct ggml_opt_context * opt, const char * pattern_it, int iteration, const char * latest) { + std::string sit = (iteration >= 0) ? std::to_string(iteration) : std::string(latest); + std::string fn = replace_str(filename, pattern_it, sit.c_str()); + printf("%s: saving to %s\n", __func__, fn.c_str()); struct gguf_context * fctx = gguf_init_empty(); save_checkpoint_gguf(fctx, fn_vocab_model, model, opt); // write file const bool only_meta = false; - gguf_write_to_file(fctx, filename, only_meta); + gguf_write_to_file(fctx, fn.c_str(), only_meta); gguf_free(fctx); } @@ -1270,6 +1285,10 @@ struct train_params { const char * fn_checkpoint_in; const char * fn_checkpoint_out; const char * fn_model_out; + const char * pattern_fn_it; + const char * fn_latest; + + int save_every; uint32_t seed; @@ -1329,6 +1348,10 @@ struct train_params get_default_train_params() { params.fn_checkpoint_in = "checkpoint.bin"; params.fn_checkpoint_out = "checkpoint.bin"; params.fn_model_out = "ggml-checkpoint-f32.bin"; + params.pattern_fn_it = "ITERATION"; + params.fn_latest = "LATEST"; + + params.save_every = 10; params.seed = -1; @@ -1392,6 +1415,9 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p fprintf(stderr, " --checkpoint-in FNAME path from which to load training checkpoint (default '%s')\n", params->fn_checkpoint_in); fprintf(stderr, " --checkpoint-out FNAME path to save training checkpoint (default '%s')\n", params->fn_checkpoint_out); fprintf(stderr, " --model-out FNAME path to save ggml model (default '%s')\n", params->fn_model_out); + fprintf(stderr, " --pattern-fn-it STR pattern in output filenames to be replaced by iteration number (default '%s')\n", params->pattern_fn_it); + fprintf(stderr, " --fn-latest STR string to use instead of iteration number for saving latest output (default '%s')\n", params->fn_latest); + fprintf(stderr, " --save-every N save checkpoint and lora every N iterations. Disabled when N <= 0. (default '%d')\n", params->save_every); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for -1)\n"); fprintf(stderr, " -c N, --ctx N Context size used during training (default %d)\n", params->n_ctx); fprintf(stderr, " --embd N Embedding size used for new models (default %d)\n", params->n_embd); @@ -1481,6 +1507,24 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { break; } params->fn_model_out = argv[i]; + } else if (arg == "--pattern-fn-it") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->pattern_fn_it = argv[i]; + } else if (arg == "--fn-latest") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->fn_latest = argv[i]; + } else if (arg == "--save-every") { + if (++i >= argc) { + invalid_param = true; + break; + } + params->save_every = std::stoi(argv[i]); } else if (arg == "-s" || arg == "--seed") { if (++i >= argc) { invalid_param = true; @@ -1722,7 +1766,9 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) { struct opt_callback_data { struct train_params * params; struct ggml_opt_context * opt; + struct my_llama_model * model; struct llama_context * lctx; + int last_save_iter; llama_token * tokens_data; size_t tokens_size; int * samples_data; @@ -1738,6 +1784,26 @@ void opt_callback(void * vdata, float * sched) { struct train_params * params = data->params; struct ggml_opt_context * opt = data->opt; int n_batch = params->n_batch; + int n_ctx = params->n_ctx; + + const bool save_now = (params->save_every > 0) && (opt->iter - data->last_save_iter >= params->save_every); + if (save_now) { + int new_iters = opt->iter - data->last_save_iter; + data->model->train_its += new_iters; + data->model->train_samples += new_iters * n_batch; + data->model->train_tokens += new_iters * n_batch * n_ctx; + + if (strlen(params->fn_checkpoint_out) > 0) { + save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, opt->iter, params->fn_latest); + save_checkpoint_file(params->fn_checkpoint_out, params->fn_vocab_model, data->model, opt, params->pattern_fn_it, -1, params->fn_latest); + + } + if (strlen(params->fn_model_out) > 0) { + save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, opt->iter, params->fn_latest); + save_llama_model_file(params->fn_model_out, params->fn_vocab_model, data->model, params->pattern_fn_it, -1, params->fn_latest); + } + data->last_save_iter = opt->iter; + } *sched = (opt->iter < params->warmup) ? (float) opt->iter / (float) params->warmup @@ -1929,7 +1995,9 @@ int main(int argc, char ** argv) { struct opt_callback_data opt_cb_data; opt_cb_data.params = ¶ms; opt_cb_data.opt = opt; + opt_cb_data.model = &model; opt_cb_data.lctx = lctx; + opt_cb_data.last_save_iter = opt->iter; opt_cb_data.tokens_data = train_tokens.data(); opt_cb_data.tokens_size = train_tokens.size(); opt_cb_data.samples_data = train_samples.data(); @@ -2038,14 +2106,23 @@ int main(int argc, char ** argv) { double dd = (double) d * 1e-3; printf("%s: total training time=%f seconds\n", __func__, dd); + int new_iters = opt->iter - opt_cb_data.last_save_iter; + model.train_its += new_iters; + model.train_samples += new_iters * n_batch; + model.train_tokens += new_iters * n_batch * n_tokens; + if (params.n_examples > 0) { - save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt); + save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt, params.pattern_fn_it, opt->iter, params.fn_latest); + save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt, params.pattern_fn_it, -1, params.fn_latest); } if (strlen(params.fn_model_out) > 0) { - save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model); + save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model, params.pattern_fn_it, opt->iter, params.fn_latest); + save_llama_model_file(params.fn_model_out, params.fn_vocab_model, &model, params.pattern_fn_it, -1, params.fn_latest); } + opt_cb_data.last_save_iter = opt->iter; + if (alloc) { ggml_allocr_free(alloc); }