llama : allow gguf rope keys to be overridden with defaults

This commit is contained in:
Cebtenzzre 2023-09-13 22:37:24 -04:00
parent 78c45b7975
commit 76988cdb9a

View file

@ -1673,28 +1673,19 @@ static void llm_load_hparams(
hparams.n_head_kv = hparams.n_head; hparams.n_head_kv = hparams.n_head;
GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV)); GGUF_GET_KEY(ctx, hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
// TODO: manually setting rope freq base and scale should override this
// FIXME: partial fix when the param specified is not the default value, but
// will not work for overriding the model value to the params default
llama_context_params defaults = llama_context_default_params(); llama_context_params defaults = llama_context_default_params();
// rope_freq_base // rope_freq_base (optional)
{ if (rope_freq_base == 0.0f) {
float ropebase = 10000.0f; rope_freq_base = 10000.0f;
GGUF_GET_KEY(ctx, ropebase, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); GGUF_GET_KEY(ctx, rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
if (ropebase != 10000.0f && rope_freq_base == defaults.rope_freq_base) {
rope_freq_base = ropebase;
}
} }
// rope_freq_scale (inverse of the kv) is optional // rope_freq_scale (inverse of the kv) is optional
{ if (rope_freq_scale == 0.0f) {
float ropescale = 1.0f; float ropescale = 1.0f;
GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR)); GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALE_LINEAR));
if (ropescale != 1.0f && rope_freq_scale == defaults.rope_freq_scale) { rope_freq_scale = 1.0f/ropescale;
rope_freq_scale = 1.0f/ropescale;
}
} }
// sanity check for n_rot (optional) // sanity check for n_rot (optional)
@ -6187,8 +6178,8 @@ struct llama_context_params llama_context_default_params() {
/*.n_gpu_layers =*/ 0, /*.n_gpu_layers =*/ 0,
/*.main_gpu =*/ 0, /*.main_gpu =*/ 0,
/*.tensor_split =*/ nullptr, /*.tensor_split =*/ nullptr,
/*.rope_freq_base =*/ 10000.0f, /*.rope_freq_base =*/ 0.0f,
/*.rope_freq_scale =*/ 1.0f, /*.rope_freq_scale =*/ 0.0f,
/*.progress_callback =*/ nullptr, /*.progress_callback =*/ nullptr,
/*.progress_callback_user_data =*/ nullptr, /*.progress_callback_user_data =*/ nullptr,
/*.low_vram =*/ false, /*.low_vram =*/ false,