load default rms_norm and rope parameters from base model

xaedes 2023-09-03 20:05:54 +02:00
parent bdb7092e82
commit 50589ed6be
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1


@@ -383,6 +383,9 @@ void print_lora_params(struct my_llama_lora_hparams * params) {
     printf("%s: n_rank_tok_embeddings : %u\n", __func__, params->n_rank_tok_embeddings);
     printf("%s: n_rank_norm : %u\n", __func__, params->n_rank_norm);
     printf("%s: n_rank_output : %u\n", __func__, params->n_rank_output);
+    printf("%s: norm_rms_eps : %f\n", __func__, params->f_norm_rms_eps);
+    printf("%s: rope_freq_base : %f\n", __func__, params->rope_freq_base);
+    printf("%s: rope_freq_scale : %f\n", __func__, params->rope_freq_scale);
 }

 void init_model(struct llama_model * input, struct my_llama_model * model, uint32_t n_ctx) {
@@ -1238,6 +1241,37 @@ void read_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, co
     }
 }

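+// read norm/rope hyperparameters from the base model gguf and use them as lora defaults;
+// values given explicitly on the command line are applied on top of these afterwards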
+void load_default_lora_params_from_base_model(const char * fn_base_model, struct my_llama_lora_hparams * lora_params) {
+    if (strlen(fn_base_model) == 0) {
+        return;
+    }
+    struct gguf_init_params params;
+    params.no_alloc = false;
+    params.ctx = NULL;
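+    // ctx == NULL: only the gguf metadata is read, no tensor data is loaded into a ggml context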
+    struct gguf_context * fctx = gguf_init_from_file(fn_base_model, params);
+    if (fctx == NULL) {
+        return;
+    }
+    const char * arch = "llama";
+    std::vector<char> keybuf;
+    keybuf.resize(512);
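+    // the LLM_KV_* key constants are printf-style templates; kv() substitutes the
+    // architecture name ("llama") into them to form the final gguf key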
+    auto kv = [arch, &keybuf](const char * key) -> const char * {
+        snprintf(keybuf.data(), keybuf.size(), key, arch);
+        return keybuf.data();
+    };
+    float rope_freq_scale = 1.0f;
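+    // all lookups are optional (last argument false): keys missing from the base model
+    // leave the current default values untouched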
+    GGUF_GET_KEY(fctx, lora_params->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+    GGUF_GET_KEY(fctx, lora_params->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
+    GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
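+    // the gguf key stores the linear rope scaling factor; the runtime parameter is its reciprocal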
+    if (rope_freq_scale != 1.0f) {
+        lora_params->rope_freq_scale = 1.0f / rope_freq_scale;
+    }
+    gguf_free(fctx);
+}
+
 void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
     // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read
@@ -1697,12 +1731,18 @@ struct train_params {
     int n_threads;
     int n_batch;
+    bool custom_n_ctx;
     bool only_write_lora;
     float f_norm_rms_eps;
     float rope_freq_base;
     float rope_freq_scale;
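+    // true when the corresponding value was given explicitly on the command line,
+    // so that it overrides the default loaded from the base model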
+    bool custom_f_norm_rms_eps;
+    bool custom_rope_freq_base;
+    bool custom_rope_freq_scale;
     int32_t lora_r;
     int32_t lora_alpha;
@@ -1765,12 +1805,18 @@ struct train_params get_default_train_params() {
     params.n_threads = 6;
     params.n_batch = 8;
+    params.custom_n_ctx = false;
     params.only_write_lora = false;
     params.f_norm_rms_eps = 1e-5f;
     params.rope_freq_base = 10000.0f;
     params.rope_freq_scale = 1.0f;
+    params.custom_f_norm_rms_eps = false;
+    params.custom_rope_freq_base = false;
+    params.custom_rope_freq_scale = false;
     params.lora_alpha = 4;
     params.lora_r = 4;
@ -1956,6 +2002,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
break;
}
params->n_ctx = std::stoi(argv[i]);
params->custom_n_ctx = true;
} else if (arg == "-t" || arg == "--threads") {
if (++i >= argc) {
invalid_param = true;
@@ -1974,18 +2021,21 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->f_norm_rms_eps = std::stof(argv[i]);
+            params->custom_f_norm_rms_eps = true;
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->rope_freq_base = std::stof(argv[i]);
+            params->custom_rope_freq_base = true;
         } else if (arg == "--rope-freq-scale") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->rope_freq_scale = std::stof(argv[i]);
+            params->custom_rope_freq_scale = true;
         } else if (arg == "--lora-alpha") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -2411,10 +2461,18 @@ int main(int argc, char ** argv) {
     opt->ctx = NULL;

+    load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams);
+
     // set lora params from command line
-    lora.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
-    lora.hparams.rope_freq_base = params.rope_freq_base;
-    lora.hparams.rope_freq_scale = params.rope_freq_scale;
+    if (params.custom_f_norm_rms_eps) {
+        lora.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
+    }
+    if (params.custom_rope_freq_base) {
+        lora.hparams.rope_freq_base = params.rope_freq_base;
+    }
+    if (params.custom_rope_freq_scale) {
+        lora.hparams.rope_freq_scale = params.rope_freq_scale;
+    }
     lora.hparams.lora_r = params.lora_r;
     lora.hparams.lora_alpha = params.lora_alpha;
     lora.hparams.n_rank_attention_norm = params.n_rank_attention_norm;
@@ -2465,7 +2523,10 @@ int main(int argc, char ** argv) {
     bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, opt);
     if (existed) {
-        model.hparams.n_ctx = params.n_ctx;
+        // overwrite last n_ctx with user provided n_ctx
+        if (params.custom_n_ctx) {
+            model.hparams.n_ctx = params.n_ctx;
+        }

         const bool opt_param_count_changed = (
             (lora.hparams.n_rank_attention_norm != params.n_rank_attention_norm)