save & load opt->just_initialized value
This commit is contained in:
parent
2978e03086
commit
0c494cc60e
1 changed file with 18 additions and 8 deletions
@@ -244,6 +244,7 @@ const char * LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT = "optimizer.convergenc
 const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT_LOW = "optimizer.parameter_count.low";
 const char * LLM_KV_OPTIMIZER_PARAMETER_COUNT_HIGH = "optimizer.parameter_count.high";
 const char * LLM_KV_OPTIMIZER_ITERATION_COUNT = "optimizer.iteration_count";
+const char * LLM_KV_OPTIMIZER_JUST_INITIALIZED = "optimizer.just_initialized";
 const char * LLM_KV_OPTIMIZER_ADAM_BEST_LOSS = "optimizer.adam.best_loss";
 const char * LLM_KV_OPTIMIZER_ADAM_PREVIOUS_LOSS = "optimizer.adam.previous_loss";
 const char * LLM_KV_OPTIMIZER_ADAM_NO_IMPROVEMENT_COUNT = "optimizer.adam.no_improvement_count";
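Note: these string constants form the GGUF key/value schema for the optimizer state; the new optimizer.just_initialized key records whether the optimizer has taken a step yet. As a rough illustration of how such a schema is exercised, here is a minimal sketch against the public gguf API (not code from this commit); the fallback default for checkpoints written before the key existed is an assumption.

    #include "ggml.h"   // gguf_* API, as used by this example program

    // Sketch only: write the optimizer flag and iteration count under their schema keys ...
    static void write_opt_flags(struct gguf_context * fctx, bool just_initialized, uint32_t iter) {
        gguf_set_val_u32 (fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT,  iter);
        gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, just_initialized);
    }

    // ... and read the flag back defensively, tolerating checkpoints that lack the key.
    static bool read_just_initialized(const struct gguf_context * fctx) {
        const int kid = gguf_find_key(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
        if (kid >= 0 && gguf_get_kv_type(fctx, kid) == GGUF_TYPE_BOOL) {
            return gguf_get_val_bool(fctx, kid);
        }
        return true; // hypothetical default for older checkpoints without the key
    }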
@@ -1527,6 +1528,7 @@ void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_g
 
     GGUF_GET_KEY(fctx, opt->params.past, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_CONVERGENCE_PAST_COUNT);
     GGUF_GET_KEY(fctx, opt->iter, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_OPTIMIZER_ITERATION_COUNT);
+    GGUF_GET_KEY(fctx, opt->just_initialized, gguf_get_val_bool, GGUF_TYPE_BOOL, true, LLM_KV_OPTIMIZER_JUST_INITIALIZED);
 
     // gguf v1 only supports values with up to 32-bit precision
     uint32_t nx[2] = { 0, 0 };
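Note: GGUF_GET_KEY is a convenience macro defined in this example program. The gist is a typed key lookup that fills the destination when the key is present and fails when a required key is missing; the following is a hedged reconstruction of that idea, not the exact macro from the file (the real one may report errors differently).

    // Sketch of a GGUF_GET_KEY-style helper: look up `key`, check its declared
    // type, and either assign the value or assert if the key is required.
    #define GGUF_GET_KEY(ctx, dst, func, type, req, key)                 \
        do {                                                             \
            const int kid = gguf_find_key((ctx), (key));                 \
            if (kid >= 0) {                                              \
                GGML_ASSERT(gguf_get_kv_type((ctx), kid) == (type));     \
                (dst) = func((ctx), kid);                                \
            } else {                                                     \
                GGML_ASSERT(!(req) && "required gguf key is missing");   \
            }                                                            \
        } while (0)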
@@ -1587,12 +1589,14 @@ void save_opt_context_gguf(struct gguf_context * fctx, struct ggml_opt_context *
 
     // gguf v1 only supports values with up to 32-bit precision,
     uint32_t nx[2] = { 0, 0 };
-    memcpy(&nx[0], &opt->nx, sizeof(opt->nx));
+    nx[0] = opt->nx & 0xFFFFFFFF;
+    nx[1] = (opt->nx >> 32) & 0xFFFFFFFF;
     gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT_LOW, nx[0]);
     gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_PARAMETER_COUNT_HIGH, nx[1]);
     // TODO set as 64-bit uint
 
     gguf_set_val_u32(fctx, LLM_KV_OPTIMIZER_ITERATION_COUNT, opt->iter);
+    gguf_set_val_bool(fctx, LLM_KV_OPTIMIZER_JUST_INITIALIZED, opt->just_initialized);
 
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
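Note: the parameter count opt->nx is 64-bit, but gguf v1 values are limited to 32 bits, so the save path now splits it explicitly into low/high halves instead of memcpy'ing it into a uint32_t[2] (the shift form does not depend on host byte order, which is presumably the motivation). A self-contained sketch of the round trip, with the reassembly the matching load side has to perform; the helper names are illustrative, not from the file:

    #include <cstdint>
    #include <cassert>

    // Split a 64-bit count into the two 32-bit values stored under
    // LLM_KV_OPTIMIZER_PARAMETER_COUNT_LOW / _HIGH, then recombine them.
    static void split_u64(uint64_t nx, uint32_t & lo, uint32_t & hi) {
        lo = (uint32_t)(nx & 0xFFFFFFFF);
        hi = (uint32_t)((nx >> 32) & 0xFFFFFFFF);
    }

    static uint64_t join_u64(uint32_t lo, uint32_t hi) {
        return (uint64_t)lo | ((uint64_t)hi << 32);
    }

    int main() {
        const uint64_t nx = 0x1122334455667788ULL;
        uint32_t lo = 0, hi = 0;
        split_u64(nx, lo, hi);
        assert(lo == 0x55667788u && hi == 0x11223344u);
        assert(join_u64(lo, hi) == nx); // round trip is lossless
        return 0;
    }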
@@ -1685,18 +1689,24 @@ void load_llama_model_gguf(struct gguf_context * fctx, struct ggml_context * f_g
     GGUF_GET_KEY(fctx, ftype_u, gguf_get_val_u32, GGUF_TYPE_UINT32, true, LLM_KV_GENERAL_FILE_TYPE);
     GGML_ASSERT((enum llama_ftype) ftype_u == LLAMA_FTYPE_ALL_F32);
 
-    GGUF_GET_KEY(fctx, model->hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_CONTEXT_LENGTH));
+    // n_ctx was not saved in earlier checkpoint file versions, so we make it optional here
+    GGUF_GET_KEY(fctx, model->hparams.n_ctx, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_CONTEXT_LENGTH));
 
     GGUF_GET_KEY(fctx, model->hparams.n_embd, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_EMBEDDING_LENGTH));
     GGUF_GET_KEY(fctx, model->hparams.n_ff, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_FEED_FORWARD_LENGTH));
     GGUF_GET_KEY(fctx, model->hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
     GGUF_GET_KEY(fctx, model->hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));
-    GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ROPE_DIMENSION_COUNT));
-
-    float rope_freq_scale;
-    GGUF_GET_KEY(fctx, model->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
-    GGUF_GET_KEY(fctx, model->hparams.rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ROPE_FREQ_BASE));
-    GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head;
+    GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
+
+    float rope_freq_scale = 1.0f;
+    GGUF_GET_KEY(fctx, model->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+    GGUF_GET_KEY(fctx, model->hparams.rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
+    GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    if (rope_freq_scale != 1.0f) {
+        model->hparams.rope_freq_scale = 1.0f / rope_freq_scale;
+    }
 
     init_model(model);
 
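Note: in this hunk the checkpoint stores the user-facing linear RoPE scale under rope.scale_linear, while hparams.rope_freq_scale holds its reciprocal, so a non-default value is inverted on load; keys that are now optional simply leave the preset defaults in place, which is what keeps older checkpoint files loadable. A quick worked example of the conversion (illustrative values only, standalone rather than code from the file):

    #include <cstdio>

    // Convert a linear RoPE scale read from a checkpoint into the internal
    // frequency-scale form; 1.0f means "no scaling" and is left untouched.
    int main() {
        float rope_freq_scale_hparam = 1.0f;  // default when the key is absent
        const float rope_scale_linear = 2.0f; // hypothetical value from the checkpoint

        if (rope_scale_linear != 1.0f) {
            rope_freq_scale_hparam = 1.0f / rope_scale_linear; // 2.0 -> 0.5
        }
        printf("rope_freq_scale = %f\n", rope_freq_scale_hparam); // prints 0.500000
        return 0;
    }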