allow forcing ext_factor to zero if scaling type is YaRN

Author: Cebtenzzre
Date:   2023-10-07 13:20:33 -04:00
Parent: 4f4e94804d
Commit: 5d7a3a5c0d


@@ -54,6 +54,7 @@
 #include <cassert>
 #include <cinttypes>
 #include <climits>
+#include <cmath>
 #include <cstdarg>
 #include <cstddef>
 #include <cstdint>
@@ -1735,7 +1736,7 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
     // rope_freq_scale (inverse of the kv) is optional
     if (rope_scaling_type == LLAMA_ROPE_SCALING_NONE) {
-        hparams.rope_freq_scale = 1.0f;
+        hparams.rope_freq_scale = 1.0f; // never scale if scaling type is none
     } else if (hparams.rope_freq_scale == 0.0f) {
         float ropescale = 0.0f;
         GGUF_GET_KEY(ctx, ropescale, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_SCALING_FACTOR));
@@ -1745,8 +1746,8 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
         hparams.rope_freq_scale = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
     }
-    if (rope_scaling_type == LLAMA_ROPE_SCALING_YARN) {
-        hparams.yarn_ext_factor = 1.0f; // enable YaRN
+    if (std::isnan(hparams.yarn_ext_factor)) { // NaN indicates 'not set'
+        hparams.yarn_ext_factor = rope_scaling_type == LLAMA_ROPE_SCALING_YARN ? 1.0f : 0.0f;
     }

     // sanity check for n_rot (optional)
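
This hunk replaces the unconditional YaRN override with a NaN sentinel: a caller-supplied value, including 0.0f, is now left alone, and only an unset (NaN) value falls back to the per-scaling-type default. The <cmath> include added in the first hunk supplies NAN and std::isnan for this. Below is a minimal standalone sketch of the resolution rule, using simplified stand-in names (resolve_ext_factor, rope_scaling) rather than llama.cpp's actual types:

// Minimal standalone sketch of the NaN-sentinel default; names other than
// std::isnan/NAN are illustrative stand-ins, not llama.cpp identifiers.
#include <cmath>   // NAN, std::isnan
#include <cstdio>

enum rope_scaling { ROPE_SCALING_NONE, ROPE_SCALING_LINEAR, ROPE_SCALING_YARN };

// Resolve the effective ext_factor: NaN means "caller did not set it",
// so fall back to 1.0f for YaRN (enabled) and 0.0f otherwise (disabled).
static float resolve_ext_factor(float requested, rope_scaling type) {
    if (std::isnan(requested)) {
        return type == ROPE_SCALING_YARN ? 1.0f : 0.0f;
    }
    return requested; // an explicit value, including 0.0f, is kept as-is
}

int main() {
    printf("%.1f\n", resolve_ext_factor(NAN,  ROPE_SCALING_YARN)); // 1.0 (default)
    printf("%.1f\n", resolve_ext_factor(0.0f, ROPE_SCALING_YARN)); // 0.0 (forced off)
    printf("%.1f\n", resolve_ext_factor(NAN,  ROPE_SCALING_NONE)); // 0.0
    return 0;
}
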
@@ -6268,7 +6269,7 @@ struct llama_context_params llama_context_default_params() {
         /*.tensor_split     =*/ nullptr,
         /*.rope_freq_base   =*/ 0.0f,
         /*.rope_freq_scale  =*/ 0.0f,
-        /*.yarn_ext_factor  =*/ 0.0f,
+        /*.yarn_ext_factor  =*/ NAN,
         /*.yarn_attn_factor =*/ 1.0f,
         /*.yarn_beta_fast   =*/ 32.0f,
         /*.yarn_beta_slow   =*/ 1.0f,
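
The NAN default here is what makes zero expressible: with the old 0.0f default, an explicit request for 0.0f was indistinguishable from "not set" and was overridden to 1.0f whenever the scaling type was YaRN. A sketch of the caller-side effect, assuming llama.h from this revision:

// Hedged usage sketch: only NaN now triggers the YaRN default, so an
// explicit 0.0f survives loading. Whether 0.0f fully reduces YaRN to
// plain interpolation is an assumption about the kernel, not shown here.
#include "llama.h"

int main() {
    llama_context_params params = llama_context_default_params();
    params.yarn_ext_factor = 0.0f; // before this commit: silently forced to 1.0f under YaRN
    return 0;
}
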