diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 9db0f1afa..b62c19540 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -2025,6 +2025,321 @@ void save_checkpoint_file(const char * filename, const char * fn_vocab_model, st
     gguf_free(fctx);
 }
 
+struct llama_file {
+    // use FILE * so we don't have to re-open the file to mmap
+    FILE * fp;
+    size_t size;
+
+    llama_file(const char * fname, const char * mode) {
+        fp = std::fopen(fname, mode);
+        if (fp == NULL) {
+            size = 0;
+        } else {
+            seek(0, SEEK_END);
+            size = tell();
+            seek(0, SEEK_SET);
+        }
+    }
+
+    size_t tell() const {
+#ifdef _WIN32
+        __int64 ret = _ftelli64(fp);
+#else
+        long ret = std::ftell(fp);
+#endif
+        GGML_ASSERT(ret != -1); // this really shouldn't fail
+        return (size_t) ret;
+    }
+
+    void seek(size_t offset, int whence) {
+#ifdef _WIN32
+        int ret = _fseeki64(fp, (__int64) offset, whence);
+#else
+        int ret = std::fseek(fp, (long) offset, whence);
+#endif
+        GGML_ASSERT(ret == 0); // same
+    }
+
+    void read_raw(void * ptr, size_t size) {
+        if (size == 0) {
+            return;
+        }
+        errno = 0;
+        std::size_t ret = std::fread(ptr, size, 1, fp);
+        if (ferror(fp)) {
+            throw std::runtime_error(format("read error: %s", strerror(errno)));
+        }
+        if (ret != 1) {
+            throw std::runtime_error(std::string("unexpectedly reached end of file"));
+        }
+    }
+
+    std::uint32_t read_u32() {
+        std::uint32_t ret;
+        read_raw(&ret, sizeof(ret));
+        return ret;
+    }
+
+    std::string read_string(std::uint32_t len) {
+        std::vector<char> chars(len);
+        read_raw(chars.data(), len);
+        return std::string(chars.data(), len);
+    }
+
+    void write_raw(const void * ptr, size_t size) {
+        if (size == 0) {
+            return;
+        }
+        errno = 0;
+        size_t ret = std::fwrite(ptr, size, 1, fp);
+        if (ret != 1) {
+            throw std::runtime_error(format("write error: %s", strerror(errno)));
+        }
+    }
+
+    void write_u32(std::uint32_t val) {
+        write_raw(&val, sizeof(val));
+    }
+
+    ~llama_file() {
+        if (fp) {
+            std::fclose(fp);
+        }
+    }
+};
+
+void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
+    if (tensor == NULL) {
+        file->write_u32(0);
+        file->write_u32(0);
+        file->write_u32(GGML_TYPE_F32);
+        file->seek((0-file->tell()) & 31, SEEK_CUR);
+        printf("%s: write tensor name='%s' data offset='%zu' nbytes='%zu'\n",
+            __func__, "(empty tensor)", file->tell(), (size_t) 0);
+        return;
+    }
+    const char * name = ggml_get_name(tensor);
+    uint32_t name_len = strlen(name);
+    uint32_t nd = tensor->n_dims;
+    uint32_t ne[4] = { (uint32_t)tensor->ne[0],
+                       (uint32_t)tensor->ne[1],
+                       (uint32_t)tensor->ne[2],
+                       (uint32_t)tensor->ne[3] };
+    printf("%s: write tensor name='%s' begin offset='%zu'\n",
+        __func__, name, file->tell());
+    file->write_u32(nd);
+    file->write_u32(name_len);
+    file->write_u32(tensor->type);
+    file->write_raw(ne, sizeof(ne[0]) * nd);
+    file->write_raw(name, name_len);
+    file->seek((0-file->tell()) & 31, SEEK_CUR);
+    printf("%s: write tensor name='%s' data offset='%zu' nbytes='%zu'\n",
+        __func__, name, file->tell(), ggml_nbytes(tensor));
+    file->write_raw(tensor->data, ggml_nbytes(tensor));
+}
+
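+// The record format written by write_tensor() is:
+//   u32 nd, u32 name_len, u32 type, u32 ne[nd], raw name bytes,
+//   padding up to the next 32-byte boundary, then the raw tensor data.
+// (0 - file->tell()) & 31 yields that padding size: for tell() % 32 == r
+// it evaluates to (32 - r) & 31, i.e. 0..31 bytes.
+// A matching read side (hypothetical sketch, not part of this patch) would be:
+//   uint32_t nd       = file->read_u32();
+//   uint32_t name_len = file->read_u32();
+//   enum ggml_type type = (enum ggml_type) file->read_u32();
+//   uint32_t ne[4] = { 1, 1, 1, 1 };
+//   file->read_raw(ne, sizeof(ne[0]) * nd);
+//   std::string name = file->read_string(name_len);
+//   file->seek((0 - file->tell()) & 31, SEEK_CUR);   // skip alignment padding
+//   file->read_raw(tensor->data, ggml_nbytes(tensor)); // tensor pre-allocated to ne[]/type
+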
+struct ggml_opt_params_v0 {
+    enum ggml_opt_type type;
+    int n_threads;
+    int past;
+    float delta;
+    int max_no_improvement;
+    bool print_forward_graph;
+    bool print_backward_graph;
+    struct {
+        int n_iter;
+        float sched;
+        float decay;
+        float alpha;
+        float beta1;
+        float beta2;
+        float eps;
+        float eps_f;
+        float eps_g;
+    } adam;
+    struct {
+        int m;
+        int n_iter;
+        int max_linesearch;
+        float eps;
+        float ftol;
+        float wolfe;
+        float min_step;
+        float max_step;
+        enum ggml_linesearch linesearch;
+    } lbfgs;
+};
+
+void write_opt_context_v0(struct llama_file * file, struct ggml_opt_context * opt) {
+    const uint32_t version = 0;
+    GGML_ASSERT(opt->nx   >= 0);
+    GGML_ASSERT(opt->iter >= 0);
+    file->write_u32(version);
+    ggml_opt_params_v0 params_v0;
+    params_v0.type                 = opt->params.type;
+    params_v0.n_threads            = opt->params.n_threads;
+    params_v0.past                 = opt->params.past;
+    params_v0.delta                = opt->params.delta;
+    params_v0.max_no_improvement   = opt->params.max_no_improvement;
+    params_v0.print_forward_graph  = opt->params.print_forward_graph;
+    params_v0.print_backward_graph = opt->params.print_backward_graph;
+    params_v0.adam.n_iter = opt->params.adam.n_iter;
+    params_v0.adam.sched  = opt->params.adam.sched;
+    params_v0.adam.decay  = opt->params.adam.decay;
+    params_v0.adam.alpha  = opt->params.adam.alpha;
+    params_v0.adam.beta1  = opt->params.adam.beta1;
+    params_v0.adam.beta2  = opt->params.adam.beta2;
+    params_v0.adam.eps    = opt->params.adam.eps;
+    params_v0.adam.eps_f  = opt->params.adam.eps_f;
+    params_v0.adam.eps_g  = opt->params.adam.eps_g;
+    params_v0.lbfgs.m              = opt->params.lbfgs.m;
+    params_v0.lbfgs.n_iter         = opt->params.lbfgs.n_iter;
+    params_v0.lbfgs.max_linesearch = opt->params.lbfgs.max_linesearch;
+    params_v0.lbfgs.eps      = opt->params.lbfgs.eps;
+    params_v0.lbfgs.ftol     = opt->params.lbfgs.ftol;
+    params_v0.lbfgs.wolfe    = opt->params.lbfgs.wolfe;
+    params_v0.lbfgs.min_step = opt->params.lbfgs.min_step;
+    params_v0.lbfgs.max_step = opt->params.lbfgs.max_step;
+    file->write_raw(&params_v0, sizeof(params_v0));
+    file->write_raw(&opt->nx,   sizeof(opt->nx));
+    file->write_raw(&opt->iter, sizeof(opt->iter));
+    file->write_u32((uint32_t) opt->just_initialized);
+    switch (opt->params.type) {
+        case GGML_OPT_ADAM:
+            {
+                struct ggml_tensor * adam_x  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, opt->nx);
+                struct ggml_tensor * adam_g1 = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, opt->nx);
+                struct ggml_tensor * adam_g2 = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, opt->nx);
+                struct ggml_tensor * adam_mh = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, opt->nx);
+                struct ggml_tensor * adam_vh = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, opt->nx);
+                write_tensor(file, adam_x);
+                write_tensor(file, adam_g1);
+                write_tensor(file, adam_g2);
+                write_tensor(file, opt->adam.m);
+                write_tensor(file, opt->adam.v);
+                write_tensor(file, adam_mh);
+                write_tensor(file, adam_vh);
+                write_tensor(file, opt->adam.pf);
+                file->write_raw(&opt->adam.fx_best,          sizeof(opt->adam.fx_best));
+                file->write_raw(&opt->adam.fx_prev,          sizeof(opt->adam.fx_prev));
+                file->write_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement));
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                write_tensor(file, opt->lbfgs.x);
+                write_tensor(file, opt->lbfgs.xp);
+                write_tensor(file, opt->lbfgs.g);
+                write_tensor(file, opt->lbfgs.gp);
+                write_tensor(file, opt->lbfgs.d);
+                write_tensor(file, opt->lbfgs.pf);
+                write_tensor(file, opt->lbfgs.lmal);
+                write_tensor(file, opt->lbfgs.lmys);
+                write_tensor(file, opt->lbfgs.lms);
+                write_tensor(file, opt->lbfgs.lmy);
+                file->write_raw(&opt->lbfgs.fx_best,          sizeof(opt->lbfgs.fx_best));
+                file->write_raw(&opt->lbfgs.step,             sizeof(opt->lbfgs.step));
+                file->write_raw(&opt->lbfgs.j,                sizeof(opt->lbfgs.j));
+                file->write_raw(&opt->lbfgs.k,                sizeof(opt->lbfgs.k));
+                file->write_raw(&opt->lbfgs.end,              sizeof(opt->lbfgs.end));
+                file->write_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement));
+            } break;
+    }
+}
+
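+// Note: version 0 serializes ggml_opt_params_v0 with a single write_raw()
+// of the whole struct, so the bytes on disk include whatever padding the
+// compiler inserts; the struct above exists only to freeze that legacy
+// layout. The adam_x/adam_g1/adam_g2/adam_mh/adam_vh tensors are freshly
+// allocated and never initialized -- they are written only to occupy the
+// slots the old v0 format reserved for optimizer working buffers, whose
+// contents a reader can presumably discard.
+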
+void write_opt_context_v1(struct llama_file * file, struct ggml_opt_context * opt) {
+    const uint32_t version = 1;
+    GGML_ASSERT(opt->nx   >= 0);
+    GGML_ASSERT(opt->iter >= 0);
+    file->write_u32(version);
+    file->write_u32(opt->params.past);
+    file->write_u32(opt->params.lbfgs.m);
+    file->write_raw(&opt->nx,   sizeof(opt->nx));
+    file->write_raw(&opt->iter, sizeof(opt->iter));
+    file->write_u32((uint32_t) opt->just_initialized);
+    switch (opt->params.type) {
+        case GGML_OPT_ADAM:
+            {
+                GGML_ASSERT(opt->adam.m != NULL);
+                GGML_ASSERT(opt->adam.v != NULL);
+                write_tensor(file, opt->adam.m);
+                write_tensor(file, opt->adam.v);
+                write_tensor(file, opt->adam.pf);
+                file->write_raw(&opt->adam.fx_best,          sizeof(opt->adam.fx_best));
+                file->write_raw(&opt->adam.fx_prev,          sizeof(opt->adam.fx_prev));
+                file->write_raw(&opt->adam.n_no_improvement, sizeof(opt->adam.n_no_improvement));
+            } break;
+        case GGML_OPT_LBFGS:
+            {
+                GGML_ASSERT(opt->lbfgs.x != NULL);
+                write_tensor(file, opt->lbfgs.x);
+                write_tensor(file, opt->lbfgs.xp);
+                write_tensor(file, opt->lbfgs.g);
+                write_tensor(file, opt->lbfgs.gp);
+                write_tensor(file, opt->lbfgs.d);
+                write_tensor(file, opt->lbfgs.pf);
+                write_tensor(file, opt->lbfgs.lmal);
+                write_tensor(file, opt->lbfgs.lmys);
+                write_tensor(file, opt->lbfgs.lms);
+                write_tensor(file, opt->lbfgs.lmy);
+                file->write_raw(&opt->lbfgs.fx_best,          sizeof(opt->lbfgs.fx_best));
+                file->write_raw(&opt->lbfgs.step,             sizeof(opt->lbfgs.step));
+                file->write_raw(&opt->lbfgs.j,                sizeof(opt->lbfgs.j));
+                file->write_raw(&opt->lbfgs.k,                sizeof(opt->lbfgs.k));
+                file->write_raw(&opt->lbfgs.end,              sizeof(opt->lbfgs.end));
+                file->write_raw(&opt->lbfgs.n_no_improvement, sizeof(opt->lbfgs.n_no_improvement));
+            } break;
+    }
+}
+
+void save_checkpoint(struct my_llama_model * model, struct ggml_opt_context * opt, const char * filename, int opt_version) {
+    struct llama_file file(filename, "wb");
+    if (file.fp == NULL) {
+        return;
+    }
+
+    const uint32_t magic   = 'ggcp';
+    const uint32_t version = 0;
+
+    file.write_u32(magic);
+    file.write_u32(version);
+    file.write_u32(model->train_its);
+    file.write_u32(model->train_samples);
+    file.write_u32(model->train_tokens);
+    file.write_u32(model->hparams.n_vocab);
+    file.write_u32(model->hparams.n_embd);
+    file.write_u32(/*model->hparams.n_mult*/ 256);
+    file.write_u32(model->hparams.n_head);
+    file.write_u32(model->hparams.n_layer);
+    file.write_u32(model->hparams.n_rot);
+
+    write_tensor(&file, model->tok_embeddings);
+    write_tensor(&file, model->norm);
+    write_tensor(&file, model->output);
+
+    for (uint32_t i = 0; i < model->hparams.n_layer; ++i) {
+        auto & layer = model->layers[i];
+
+        write_tensor(&file, layer.attention_norm);
+        write_tensor(&file, layer.wq);
+        write_tensor(&file, layer.wk);
+        write_tensor(&file, layer.wv);
+        write_tensor(&file, layer.wo);
+        write_tensor(&file, layer.ffn_norm);
+        write_tensor(&file, layer.w1);
+        write_tensor(&file, layer.w2);
+        write_tensor(&file, layer.w3);
+    }
+
+    if (opt_version == 0) {
+        write_opt_context_v0(&file, opt);
+    } else {
+        write_opt_context_v1(&file, opt);
+    }
+
+    printf("%s: all written offset='%zu'\n",
+        __func__, file.tell());
+
+}
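+// Resulting 'ggcp' checkpoint layout (for reference):
+//   u32 magic 'ggcp', u32 version 0,
+//   u32 train_its, train_samples, train_tokens,
+//   u32 n_vocab, n_embd, n_mult (hard-coded 256), n_head, n_layer, n_rot,
+//   tensor records: tok_embeddings, norm, output, then 9 tensors per layer,
+//   optimizer state via write_opt_context_v0() or _v1().
+// Hypothetical usage (mirrors the call added in main() below):
+//   save_checkpoint(&model, opt, "checkpoint.0.old.bin", /*opt_version=*/0);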
 float cosine_decay(const int decay_steps, const float minimum, int step) {
     if (step > decay_steps) {
         step = decay_steps;
@@ -2875,6 +3190,15 @@ int main(int argc, char ** argv) {
         printf("%s: total training time=%f seconds\n", __func__, dd);
 
         if (params.n_examples > 0) {
+            for (int opt_version = 0; opt_version < 2; ++opt_version) {
+                std::string fn_checkpoint_out_old = (
+                    std::string(params.fn_checkpoint_out)
+                    + std::string(".")
+                    + std::to_string(opt_version)
+                    + std::string(".old.bin"));
+                save_checkpoint(&model, opt, fn_checkpoint_out_old.c_str(), opt_version);
+            }
+
             save_checkpoint_file(params.fn_checkpoint_out, params.fn_vocab_model, &model, opt);
         }