From b241b9cb6cf000c0d3e16e4aecda11bb08e77537 Mon Sep 17 00:00:00 2001
From: xaedes
Date: Wed, 17 May 2023 13:49:32 +0200
Subject: [PATCH] save trained model to checkpoint and load model to be
 trained from checkpoint

---
 examples/baby-llama/baby-llama-text.cpp | 149 ++++++++++++++++++++++--
 1 file changed, 140 insertions(+), 9 deletions(-)

diff --git a/examples/baby-llama/baby-llama-text.cpp b/examples/baby-llama/baby-llama-text.cpp
index 2de3171f1..542b5e386 100644
--- a/examples/baby-llama/baby-llama-text.cpp
+++ b/examples/baby-llama/baby-llama-text.cpp
@@ -204,6 +204,9 @@ struct my_llama_model {
     struct ggml_tensor * output;
 
     std::vector<my_llama_layer> layers;
+
+    uint32_t train_its = 0;
+    uint32_t train_samples = 0;
 };
 
 uint32_t get_n_ff(const struct my_llama_hparams* hparams) {
@@ -1124,11 +1127,12 @@ struct llama_file {
     llama_file(const char * fname, const char * mode) {
         fp = std::fopen(fname, mode);
         if (fp == NULL) {
-            throw std::runtime_error(format("failed to open %s: %s", fname, strerror(errno)));
+            size = 0;
+        } else {
+            seek(0, SEEK_END);
+            size = tell();
+            seek(0, SEEK_SET);
         }
-        seek(0, SEEK_END);
-        size = tell();
-        seek(0, SEEK_SET);
     }
 
     size_t tell() const {
@@ -1355,18 +1359,135 @@ void set_logits_masked(struct ggml_tensor * logits, std::vector<bool>& mask, flo
     }
 }
 
+enum llama_file_version {
+    LLAMA_FILE_VERSION_GGML,
+    LLAMA_FILE_VERSION_GGMF_V1, // added version field and scores in vocab
+    LLAMA_FILE_VERSION_GGJT_V1, // added padding
+    LLAMA_FILE_VERSION_GGJT_V2, // changed quantization format
+};
+
+void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
+    const char * name = ggml_get_name(tensor);
+    uint32_t name_len = strlen(name);
+    uint32_t nd = tensor->n_dims;
+    uint32_t ne[4] = { tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3] };
+    file->write_u32(nd);
+    file->write_u32(name_len);
+    file->write_u32(tensor->type);
+    file->write_raw(ne, sizeof(ne[0]) * nd);
+    file->write_raw(name, name_len);
+    file->seek(-file->tell() & 31, SEEK_CUR);
+    file->write_raw(tensor->data, ggml_nbytes(tensor));
+}
+
+void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
+    uint32_t nd = file->read_u32();
+    GGML_ASSERT(nd == tensor->n_dims);
+    uint32_t name_len = file->read_u32();
+    enum ggml_type type = (enum ggml_type) file->read_u32();
+    GGML_ASSERT(type == tensor->type);
+    uint32_t ne[4];
+    file->read_raw(ne, sizeof(ne[0]) * nd);
+    for (int i=0; i<nd; ++i) {
+        GGML_ASSERT(ne[i] == tensor->ne[i]);
+    }
+    std::string name = file->read_string(name_len);
+    file->seek(-file->tell() & 31, SEEK_CUR);
+
+    GGML_ASSERT(strcmp(ggml_get_name(tensor), name.c_str()) == 0);
+    file->read_raw(tensor->data, ggml_nbytes(tensor));
+}
+
+void save_model(struct my_llama_model * model, const char * filename) {
+    struct llama_file file(filename, "wb");
+    if (file.fp == NULL) {
+        return;
+    }
+    file.write_u32(model->train_its);
+    file.write_u32(model->train_samples);
+    file.write_u32(model->hparams.n_vocab);
+    file.write_u32(model->hparams.n_embd);
+    file.write_u32(model->hparams.n_mult);
+    file.write_u32(model->hparams.n_head);
+    file.write_u32(model->hparams.n_layer);
+    file.write_u32(model->hparams.n_rot);
+
+    write_tensor(&file, model->tok_embeddings);
+    write_tensor(&file, model->norm);
+    write_tensor(&file, model->output);
+
+    for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
+        auto & layer = model->layers[i];
+
+        write_tensor(&file, layer.attention_norm);
+        write_tensor(&file, layer.wq);
+        write_tensor(&file, layer.wk);
+        write_tensor(&file, layer.wv);
+        write_tensor(&file, layer.wo);
+        write_tensor(&file, layer.ffn_norm);
+        write_tensor(&file, layer.w1);
+        write_tensor(&file, layer.w2);
+        write_tensor(&file, layer.w3);
+    }
+}
+
+void load_model(struct my_llama_model * model, const char * filename, bool init) {
+    struct llama_file file(filename, "rb");
+
+    if (file.fp) {
+        printf("%s: Loading model from '%s'.\n", __func__, filename);
+        model->train_its = file.read_u32();
+        model->train_samples = file.read_u32();
+        model->hparams.n_vocab = file.read_u32();
+        model->hparams.n_embd = file.read_u32();
+        model->hparams.n_mult = file.read_u32();
+        model->hparams.n_head = file.read_u32();
+        model->hparams.n_layer = file.read_u32();
+        model->hparams.n_rot = file.read_u32();
+        printf("%s: Training iterations: %u.\n", __func__, model->train_its);
+        printf("%s: Training samples: %u.\n", __func__, model->train_samples);
+        print_params(&model->hparams);
+    }
+
+    if (init) {
+        init_model(model);
+    }
+
+    if (file.fp) {
+        read_tensor(&file, model->tok_embeddings);
+        read_tensor(&file, model->norm);
+        read_tensor(&file, model->output);
+
+        for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
+            auto & layer = model->layers[i];
+
+            read_tensor(&file, layer.attention_norm);
+            read_tensor(&file, layer.wq);
+            read_tensor(&file, layer.wk);
+            read_tensor(&file, layer.wv);
+            read_tensor(&file, layer.wo);
+            read_tensor(&file, layer.ffn_norm);
+            read_tensor(&file, layer.w1);
+            read_tensor(&file, layer.w2);
+            read_tensor(&file, layer.w3);
+        }
+    }
+}
+
 int main(int argc, char ** argv) {
     const char * default_model = "ggml-vic7b-uncensored-q4_0.bin";
     const char * default_train = "shakespeare.txt";
-    const char * default_argv[3] = {argv[0], default_model, default_train};
+    const char * default_checkpoint = "checkpoint.bin";
+    const char * default_argv[4] = {argv[0], default_model, default_train, default_checkpoint};
 
-    if (argc < 3) {
+    if (argc < 4) {
         fprintf(stderr, "usage: %s model training_data\n", argv[0]);
         //return 1;
     }
 
     const char * fn_model = (argc >= 2) ? argv[1] : default_argv[1];
     const char * fn_train = (argc >= 3) ? argv[2] : default_argv[2];
+    const char * fn_chkpt = (argc >= 4) ? argv[3] : default_argv[3];
 
     struct llama_context_params llama_params = llama_context_default_params();
     llama_params.vocab_only = true;
@@ -1420,7 +1541,7 @@ int main(int argc, char ** argv) {
     my_llama_sampler sampler;
 
     printf("%s: init model\n", __func__);
-    init_model(&model);
+    load_model(&model, fn_chkpt, true);
     set_param_model(&model);
     randomize_model(&model, 1337, 0.0f, 1.0f, -1.0f, +1.0f);
     init_kv_cache(&kv_self, &model, n_batch);
@@ -1498,8 +1619,16 @@ int main(int argc, char ** argv) {
         opt_params_lbfgs.print_backward_graph = false;
         opt_params_lbfgs.n_threads = gf.n_threads;
         opt_params_lbfgs.lbfgs.n_iter = 16;
-        ggml_opt(ctx0, opt_params_adam, e);
-        // ggml_opt(ctx0, opt_params_lbfgs, e);
+
+        bool use_adam = true;
+        if (use_adam) {
+            ggml_opt(ctx0, opt_params_adam, e);
+        } else {
+            ggml_opt(ctx0, opt_params_lbfgs, e);
+        }
+
+        model.train_its += use_adam ? opt_params_adam.adam.n_iter : opt_params_lbfgs.lbfgs.n_iter;
+        model.train_samples += n_batch;
 
         ggml_build_forward_expand(&gf, e);
         ggml_graph_compute(ctx0, &gf);
@@ -1541,6 +1670,8 @@ int main(int argc, char ** argv) {
         ggml_free(ctx0);
     }
 
+    save_model(&model, fn_chkpt);
+
     {
         int n_gen = 128;
         int sample_ctx = n_tokens - n_tokens/8;
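
For anyone inspecting checkpoint.bin by hand, the on-disk layout this patch
produces is: a header of eight u32 values written by save_model() in the
order train_its, train_samples, n_vocab, n_embd, n_mult, n_head, n_layer,
n_rot; then one record per tensor from write_tensor(), each consisting of
u32 n_dims, u32 name_len, u32 type, the n_dims dimension sizes as u32, the
raw name bytes, padding up to the next 32-byte file offset (that is what
file->seek(-file->tell() & 31, SEEK_CUR) advances past, 0..31 bytes), and
finally the raw tensor data. The padding echoes the GGJT format ("added
padding" in the llama_file_version enum above), keeping tensor data aligned.
Note also that load_model() relies on the changed llama_file constructor:
a missing checkpoint now yields fp == NULL and size == 0 instead of a
throw, so the first run simply initializes the model and trains from
scratch, while later runs resume.

The dumper below is a minimal standalone sketch of a reader for this
layout. It is not code from the patch, and it assumes every tensor is
GGML_TYPE_F32 (4 bytes per element), which matches how init_model()
allocates them in baby-llama-text.cpp.

// checkpoint-dump.cpp: hypothetical standalone dumper for the checkpoint
// layout written by save_model()/write_tensor() above. Not part of the
// patch; assumes all tensors are GGML_TYPE_F32 (4 bytes per element).
#include <cstdint>
#include <cstdio>
#include <string>

static uint32_t read_u32(std::FILE * fp) {
    uint32_t v = 0;
    if (std::fread(&v, sizeof(v), 1, fp) != 1) {
        v = 0; // caller detects end-of-file via std::feof()
    }
    return v;
}

int main(int argc, char ** argv) {
    const char * fn = (argc >= 2) ? argv[1] : "checkpoint.bin";
    std::FILE * fp = std::fopen(fn, "rb");
    if (fp == NULL) {
        std::fprintf(stderr, "cannot open %s\n", fn);
        return 1;
    }

    // fixed header: eight u32 values in the exact order save_model() writes
    uint32_t train_its     = read_u32(fp);
    uint32_t train_samples = read_u32(fp);
    uint32_t n_vocab = read_u32(fp);
    uint32_t n_embd  = read_u32(fp);
    uint32_t n_mult  = read_u32(fp);
    uint32_t n_head  = read_u32(fp);
    uint32_t n_layer = read_u32(fp);
    uint32_t n_rot   = read_u32(fp);
    std::printf("train_its=%u train_samples=%u n_vocab=%u n_embd=%u "
                "n_mult=%u n_head=%u n_layer=%u n_rot=%u\n",
                train_its, train_samples, n_vocab, n_embd,
                n_mult, n_head, n_layer, n_rot);

    // tensor records until EOF: u32 n_dims | u32 name_len | u32 type |
    // u32 ne[n_dims] | name bytes | padding to a 32-byte offset | data
    for (;;) {
        uint32_t nd = read_u32(fp);
        if (std::feof(fp)) {
            break;
        }
        uint32_t name_len = read_u32(fp);
        uint32_t type     = read_u32(fp);
        uint64_t nelem    = 1;
        for (uint32_t i = 0; i < nd; ++i) {
            nelem *= read_u32(fp);
        }
        std::string name(name_len, '\0');
        if (name_len > 0 && std::fread(&name[0], 1, name_len, fp) != name_len) {
            break;
        }
        long pos = std::ftell(fp);
        std::fseek(fp, -pos & 31, SEEK_CUR);          // skip alignment padding
        std::fseek(fp, (long)(nelem * 4), SEEK_CUR);  // skip F32 tensor data
        std::printf("tensor %-24s n_dims=%u type=%u elements=%llu\n",
                    name.c_str(), nd, type, (unsigned long long) nelem);
    }
    std::fclose(fp);
    return 0;
}

Compile with, e.g., g++ -std=c++11 -o checkpoint-dump checkpoint-dump.cpp
and point it at a saved checkpoint to check that the header fields and the
per-layer tensor records round-trip as expected.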