diff --git a/convert.py b/convert.py index 6e294b503..175f4b14b 100755 --- a/convert.py +++ b/convert.py @@ -234,6 +234,7 @@ class Params: n_head_kv = config.get("num_key_value_heads", n_head), f_norm_eps = config["rms_norm_eps"], f_rope_freq_base = config.get("rope_theta"), + rope_scaling_type = rope_scaling_type, f_rope_scale = f_rope_scale, n_orig_ctx = n_orig_ctx, rope_finetuned = rope_finetuned, diff --git a/ggml-cuda.cu b/ggml-cuda.cu index 95e1ae4c6..ff7b1e90a 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -4429,8 +4429,8 @@ static __device__ void rope_yarn( } // Get n-d magnitude scaling corrected for interpolation - if (freq_scale > 1.0f) - mscale *= 1.0f + 0.1f * logf(freq_scale); + if (freq_scale < 1.0f) + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); *cos_theta = cosf(theta) * mscale; *sin_theta = sinf(theta) * mscale; } diff --git a/ggml-metal.metal b/ggml-metal.metal index 6fd3f9aa0..2064884ff 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -899,8 +899,8 @@ static void rope_yarn( } // Get n-d magnitude scaling corrected for interpolation - if (freq_scale > 1.0f) - mscale *= 1.0f + 0.1f * logf(freq_scale); + if (freq_scale < 1.0f) + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); *cos_theta = cosf(theta) * mscale; *sin_theta = sinf(theta) * mscale; } diff --git a/ggml.c b/ggml.c index 4b40a4e71..a24341810 100644 --- a/ggml.c +++ b/ggml.c @@ -13364,8 +13364,8 @@ static void rope_yarn( } // Get n-d magnitude scaling corrected for interpolation - if (freq_scale > 1.0f) - mscale *= 1.0f + 0.1f * logf(freq_scale); + if (freq_scale < 1.0f) + mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale); *cos_theta = cosf(theta) * mscale; *sin_theta = sinf(theta) * mscale; } diff --git a/llama.cpp b/llama.cpp index faeee0d3a..cbab5f580 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2055,7 +2055,7 @@ static void llm_load_hparams( GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE)); std::string rope_scaling("linear"); - GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE)); + GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE)); hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling); GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);