load default rms_norm and rope parameters from base model

parent bdb7092e82
commit 50589ed6be

1 changed file with 65 additions and 4 deletions
@@ -383,6 +383,9 @@ void print_lora_params(struct my_llama_lora_hparams * params) {
     printf("%s: n_rank_tok_embeddings : %u\n", __func__, params->n_rank_tok_embeddings);
     printf("%s: n_rank_norm : %u\n", __func__, params->n_rank_norm);
     printf("%s: n_rank_output : %u\n", __func__, params->n_rank_output);
+    printf("%s: norm_rms_eps : %f\n", __func__, params->f_norm_rms_eps);
+    printf("%s: rope_freq_base : %f\n", __func__, params->rope_freq_base);
+    printf("%s: rope_freq_scale : %f\n", __func__, params->rope_freq_scale);
 }

 void init_model(struct llama_model * input, struct my_llama_model * model, uint32_t n_ctx) {
@@ -1238,6 +1241,37 @@ void read_tensor_by_name(struct ggml_tensor * dst, struct ggml_context * ctx, co
     }
 }

+void load_default_lora_params_from_base_model(const char * fn_base_model, struct my_llama_lora_hparams * lora_params) {
+    if (strlen(fn_base_model) == 0) {
+        return;
+    }
+    struct gguf_init_params params;
+    params.no_alloc = false;
+    params.ctx = NULL;
+    struct gguf_context * fctx = gguf_init_from_file(fn_base_model, params);
+    if (fctx == NULL) {
+        return;
+    }
+
+    const char * arch = "llama";
+    std::vector<char> keybuf;
+    keybuf.resize(512);
+    auto kv = [arch, &keybuf](const char * key) -> const char * {
+        snprintf(keybuf.data(), keybuf.size(), key, arch);
+        return keybuf.data();
+    };
+
+    float rope_freq_scale = 1.0f;
+    GGUF_GET_KEY(fctx, lora_params->f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+    GGUF_GET_KEY(fctx, lora_params->rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
+    GGUF_GET_KEY(fctx, rope_freq_scale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
+    if (rope_freq_scale != 1.0f) {
+        lora_params->rope_freq_scale = 1.0f / rope_freq_scale;
+    }
+
+    gguf_free(fctx);
+}
+
 void load_opt_context_gguf(struct gguf_context * fctx, struct ggml_context * f_ggml_ctx, struct ggml_opt_context * opt) {
     // NOTE: gguf_context must be initialized with f_ggml_ctx and no_alloc=false, otherwise tensor data can not be read

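An aside on the kv() idiom in the new function: the LLM_KV_* constants are printf-style format strings with a %s placeholder for the architecture name, so the lambda expands them into concrete GGUF metadata keys. A minimal standalone sketch, assuming the format string "%s.attention.layer_norm_rms_epsilon" (the real constants are defined elsewhere in this file):

// Sketch only: shows how kv() turns a format-string key constant into the
// concrete GGUF key "llama.attention.layer_norm_rms_epsilon".
#include <cstdio>
#include <vector>

int main() {
    const char * arch = "llama";
    std::vector<char> keybuf(512);
    auto kv = [arch, &keybuf](const char * key) -> const char * {
        std::snprintf(keybuf.data(), keybuf.size(), key, arch);
        return keybuf.data();
    };
    std::printf("%s\n", kv("%s.attention.layer_norm_rms_epsilon"));
    return 0;
}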
@@ -1697,12 +1731,18 @@ struct train_params {
     int n_threads;
     int n_batch;

+    bool custom_n_ctx;
+
     bool only_write_lora;

     float f_norm_rms_eps;
     float rope_freq_base;
     float rope_freq_scale;

+    bool custom_f_norm_rms_eps;
+    bool custom_rope_freq_base;
+    bool custom_rope_freq_scale;
+
     int32_t lora_r;
     int32_t lora_alpha;

@@ -1765,12 +1805,18 @@ struct train_params get_default_train_params() {
     params.n_threads = 6;
     params.n_batch = 8;

+    params.custom_n_ctx = false;
+
     params.only_write_lora = false;

     params.f_norm_rms_eps = 1e-5f;
     params.rope_freq_base = 10000.0f;
     params.rope_freq_scale = 1.0f;

+    params.custom_f_norm_rms_eps = false;
+    params.custom_rope_freq_base = false;
+    params.custom_rope_freq_scale = false;
+
     params.lora_alpha = 4;
     params.lora_r = 4;

@@ -1956,6 +2002,7 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->n_ctx = std::stoi(argv[i]);
+            params->custom_n_ctx = true;
         } else if (arg == "-t" || arg == "--threads") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -1974,18 +2021,21 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->f_norm_rms_eps = std::stof(argv[i]);
+            params->custom_f_norm_rms_eps = true;
         } else if (arg == "--rope-freq-base") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->rope_freq_base = std::stof(argv[i]);
+            params->custom_rope_freq_base = true;
         } else if (arg == "--rope-freq-scale") {
             if (++i >= argc) {
                 invalid_param = true;
                 break;
             }
             params->rope_freq_scale = std::stof(argv[i]);
+            params->custom_rope_freq_scale = true;
         } else if (arg == "--lora-alpha") {
             if (++i >= argc) {
                 invalid_param = true;
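The custom_* flags set during parsing implement a common CLI pattern: record whether an option was given, not only its value, so a default read from the base model is overridden only by an explicit flag. A hedged sketch of the idea; cli_opts and effective_rope_freq_base are hypothetical names, not code from this commit:

// Hypothetical illustration of why the parser tracks custom_* flags.
struct cli_opts {
    float rope_freq_base        = 10000.0f; // hard-coded fallback
    bool  custom_rope_freq_base = false;    // set when --rope-freq-base is parsed
};

// Prefer the explicit CLI value; otherwise fall back to the base model's value.
float effective_rope_freq_base(const cli_opts & o, float from_base_model) {
    return o.custom_rope_freq_base ? o.rope_freq_base : from_base_model;
}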
@@ -2411,10 +2461,18 @@ int main(int argc, char ** argv) {

     opt->ctx = NULL;

+    load_default_lora_params_from_base_model(params.fn_model_base, &lora.hparams);
+
     // set lora params from command line
-    lora.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
-    lora.hparams.rope_freq_base = params.rope_freq_base;
-    lora.hparams.rope_freq_scale = params.rope_freq_scale;
+    if (params.custom_f_norm_rms_eps) {
+        lora.hparams.f_norm_rms_eps = params.f_norm_rms_eps;
+    }
+    if (params.custom_rope_freq_base) {
+        lora.hparams.rope_freq_base = params.rope_freq_base;
+    }
+    if (params.custom_rope_freq_scale) {
+        lora.hparams.rope_freq_scale = params.rope_freq_scale;
+    }
     lora.hparams.lora_r = params.lora_r;
     lora.hparams.lora_alpha = params.lora_alpha;
     lora.hparams.n_rank_attention_norm = params.n_rank_attention_norm;
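Note the ordering above: defaults are loaded from the base model first, then the command-line values overwrite them, so explicit flags always win. Also, load_default_lora_params_from_base_model reads LLM_KV_ROPE_SCALE_LINEAR, a linear context-scaling factor, while the hparams store a frequency scale, hence the reciprocal. A worked example with an illustrative value:

#include <cstdio>

int main() {
    // A base model trained with 2x linear RoPE scaling stores 2.0 in its GGUF
    // metadata; the corresponding frequency scale is the reciprocal, 0.5.
    float rope_scale_linear = 2.0f;
    float rope_freq_scale   = 1.0f;
    if (rope_scale_linear != 1.0f) {
        rope_freq_scale = 1.0f / rope_scale_linear;
    }
    std::printf("rope_freq_scale = %f\n", rope_freq_scale); // 0.500000
    return 0;
}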
@@ -2465,7 +2523,10 @@ int main(int argc, char ** argv) {
     bool existed = load_checkpoint_lora_file(params.fn_checkpoint_in, &model, &lora, opt);

     if (existed) {
-        model.hparams.n_ctx = params.n_ctx;
+        // overwrite last n_ctx with user provided n_ctx
+        if (params.custom_n_ctx) {
+            model.hparams.n_ctx = params.n_ctx;
+        }

         const bool opt_param_count_changed = (
             (lora.hparams.n_rank_attention_norm != params.n_rank_attention_norm)