From 0c494cc60ebef97a5ba847203f78fae4a96491ad Mon Sep 17 00:00:00 2001
From: xaedes
Date: Sun, 27 Aug 2023 22:05:24 +0200
Subject: [PATCH] save & load opt->just_initialized value

---
 .../train-text-from-scratch.cpp              | 26 +++++++++++++------
 1 file changed, 18 insertions(+), 8 deletions(-)

diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index c7345c300..0a32cc248 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -244,6 +244,7 @@ const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergenc
 const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT_LOW       = "optimizer.parameter_count.low";
 const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT_HIGH      = "optimizer.parameter_count.high";
 const char * LLM_KV_OPTIMIZER_ITERATION_COUNT           = "optimizer.iteration_count";
+const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED          = "optimizer.just_initialized";
 const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS            = "optimizer.adam.best_loss";
 const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS        = "optimizer.adam.previous_loss";
 const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count";
@@ -1527,6 +1528,7 @@ void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_g
 
     GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
     GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
+    GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
 
     // gguf v1 only supports values with up to 32-bit precision
     uint32_t nx[2] = { 0, 0 };
@@ -1587,12 +1589,14 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context *
 
     // gguf v1 only supports values with up to 32-bit precision,
     uint32_t nx[2] = { 0, 0 };
-    memcpy(&nx[0], &opt->nx, sizeof(opt->nx));
+    nx[0] = opt->nx & 0xFFFFFFFF;
+    nx[1] = (opt->nx >> 32) & 0xFFFFFFFF;
     gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT_LOW,  nx[0]);
     gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT_HIGH, nx[1]);
     // TODO set as 64-bit uint
 
     gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
+    gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
 
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
@@ -1685,18 +1689,24 @@ void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_context * f_g
     GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE);
     GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32);
 
-    GGUF_GET_KEY(fctx, model->hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH));
+    // n_ctx was not saved in earlier checkpoint file versions, so we make it optional here
+    GGUF_GET_KEY(fctx, model->hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
+
     GGUF_GET_KEY(fctx, model->hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
     GGUF_GET_KEY(fctx, model->hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
     GGUF_GET_KEY(fctx, model->hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
     GGUF_GET_KEY(fctx, model->hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
-    GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ROPE_DIMENSION_COUNT));
 
-    float rope_freq_scale;
-    GGUF_GET_KEY(fctx, model->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
-    GGUF_GET_KEY(fctx, model->hparams.rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ROPE_FREQ_BASE));
-    GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ROPE_SCALE_LINEAR));
-    model->hparams.rope_freq_scale = 1.0f / rope_freq_scale;
+    model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head;
+    GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
+
+    float rope_freq_scale = 1.0f;
+    GGUF_GET_KEY(fctx, model->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+    GGUF_GET_KEY(fctx, model->hparams.rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
+    GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    if (rope_freq_scale != 1.0f) {
+        model->hparams.rope_freq_scale = 1.0f / rope_freq_scale;
+    }
 
     init_model(model);
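
Note on the nx change (commentary, not part of the patch): the replaced memcpy copied all sizeof(opt->nx) bytes of the 64-bit parameter count into the uint32_t nx[2] array, which yields the intended low/high split only on little-endian machines; the explicit shift-and-mask version is byte-order independent. Below is a minimal standalone sketch of the round-trip, assuming the load side mirrors the save side by recombining the two 32-bit values (the sample value is illustrative, not from the patch):

    #include <cassert>
    #include <cstdint>

    int main() {
        // stand-in for opt->nx: a parameter count that needs more than 32 bits
        uint64_t nx = 0x123456789ULL;

        // save side (cf. save_opt_context_gguf): split into low/high 32-bit halves
        uint32_t lo = (uint32_t)( nx        & 0xFFFFFFFF);
        uint32_t hi = (uint32_t)((nx >> 32) & 0xFFFFFFFF);

        // load side (assumed mirror of the above): recombine the halves
        uint64_t restored = ((uint64_t) hi << 32) | (uint64_t) lo;
        assert(restored == nx);
        return 0;
    }

The hparams hunk follows the same backward-compatibility pattern throughout: each field is first given a computed or neutral default (n_rot from n_embd / n_head, rope_freq_scale = 1.0f), and the GGUF_GET_KEY calls that pass false for the required flag overwrite it only when the key is actually present, so checkpoints written before these keys existed still load.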