From 50589ed6bea71fbf1a367a3dcebe3461f4bc112b Mon Sep 17 00:00:00 2001
From: xaedes
Date: Sun, 3 Sep 2023 20:05:54 +0200
Subject: [PATCH] load default rms_norm and rope parameters from base model

---
 examples/finetune/finetune.cpp | 69 ++++++++++++++++++++++++++++++++--
 1 file changed, 65 insertions(+), 4 deletions(-)

diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index b9a5b0447..bba1bdb7c 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -383,6 +383,9 @@ void print_lora_params(struct my_llama_lora_hparams * params) {
     printf("%s: n_rank_tok_embeddings : %u\n", __func__, params->n_rank_tok_embeddings);
     printf("%s: n_rank_norm           : %u\n", __func__, params->n_rank_norm);
     printf("%s: n_rank_output         : %u\n", __func__, params->n_rank_output);
+    printf("%s: norm_rms_eps          : %f\n", __func__, params->f_norm_rms_eps);
+    printf("%s: rope_freq_base        : %f\n", __func__, params->rope_freq_base);
+    printf("%s: rope_freq_scale       : %f\n", __func__, params->rope_freq_scale);
 }
 
 void init_model(struct llama_model * input, struct my_llama_model * model, uint32_t n_ctx) {
@@ -1238,6 +1241,37 @@ void read_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, co
     }
 }
 
+void load_default_lora_params_from_base_model(const char * fn_base_model, struct my_llama_lora_hparams * lora_params) {
+    if (strlen(fn_base_model) == 0) {
+        return;
+    }
+    struct gguf_init_params params;
+    params.no_alloc = false;
+    params.ctx = NULL;
+    struct gguf_context * fctx = gguf_init_from_file(fn_base_model, params);
+    if (fctx == NULL) {
+        return;
+    }
+
+    const char * arch = "llama";
+    std::vector<char> keybuf;
+    keybuf.resize(512);
+    auto kv = [arch, &keybuf](const char * key) -> const char * {
+        snprintf(keybuf.data(), keybuf.size(), key, arch);
+        return keybuf.data();
+    };
+
+    float rope_freq_scale = 1.0f;
+    GGUF_GET_KEY(fctx, lora_params->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+    GGUF_GET_KEY(fctx, lora_params->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
+    GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    if (rope_freq_scale != 1.0f) {
+        lora_params->rope_freq_scale = 1.0f / rope_freq_scale;
+    }
+
+    gguf_free(fctx);
+}
+
 void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
     // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
@@ -1697,12 +1731,18 @@ struct train_params {
     int n_threads;
     int n_batch;
 
+    bool custom_n_ctx;
+
     bool only_write_lora;
 
     float f_norm_rms_eps;
     float rope_freq_base;
     float rope_freq_scale;
 
+    bool custom_f_norm_rms_eps;
+    bool custom_rope_freq_base;
+    bool custom_rope_freq_scale;
+
     int32_t lora_r;
     int32_t lora_alpha;
 
@@ -1765,12 +1805,18 @@ struct train_params get_default_train_params() {
     params.n_threads = 6;
     params.n_batch   = 8;
 
+    params.custom_n_ctx = false;
+
     params.only_write_lora = false;
 
     params.f_norm_rms_eps  = 1e-5f;
     params.rope_freq_base  = 10000.0f;
     params.rope_freq_scale = 1.0f;
 
+    params.custom_f_norm_rms_eps  = false;
+    params.custom_rope_freq_base  = false;
+    params.custom_rope_freq_scale = false;
+
     params.lora_alpha = 4;
     params.lora_r     = 4;
 
@@ -1956,6 +2002,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_ctx = std::stoi(argv[i]);
+            params->custom_n_ctx = true;
         } else if (arg == "-t" || arg == "--threads") {
            if (++i >= argc) {
                invalid_param = true;
@@ -1974,18 +2021,21 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->f_norm_rms_eps = std::stof(argv[i]);
+            params->custom_f_norm_rms_eps = true;
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->rope_freq_base = std::stof(argv[i]);
+            params->custom_rope_freq_base = true;
         } else if (arg == "--rope-freq-scale") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->rope_freq_scale = std::stof(argv[i]);
+            params->custom_rope_freq_scale = true;
         } else if (arg == "--lora-alpha") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2411,10 +2461,18 @@ int main(int argc, char ** argv) {
 
     opt->ctx = NULL;
 
+    load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams);
+
     // set lora params from command line
-    lora.hparams.f_norm_rms_eps  = params.f_norm_rms_eps;
-    lora.hparams.rope_freq_base  = params.rope_freq_base;
-    lora.hparams.rope_freq_scale = params.rope_freq_scale;
+    if (params.custom_f_norm_rms_eps) {
+        lora.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
+    }
+    if (params.custom_rope_freq_base) {
+        lora.hparams.rope_freq_base = params.rope_freq_base;
+    }
+    if (params.custom_rope_freq_scale) {
+        lora.hparams.rope_freq_scale = params.rope_freq_scale;
+    }
     lora.hparams.lora_r = params.lora_r;
     lora.hparams.lora_alpha = params.lora_alpha;
     lora.hparams.n_rank_attention_norm = params.n_rank_attention_norm;
@@ -2465,7 +2523,10 @@ int main(int argc, char ** argv) {
 
     bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, opt);
     if (existed) {
-        model.hparams.n_ctx = params.n_ctx;
+        // overwrite last n_ctx with user provided n_ctx
+        if (params.custom_n_ctx) {
+            model.hparams.n_ctx = params.n_ctx;
+        }
 
         const bool opt_param_count_changed = (
             (lora.hparams.n_rank_attention_norm != params.n_rank_attention_norm)