Fix YaRN inverted scaling and add "rope.scaling.type" to GGUF (#1)
This commit is contained in:
parent
9ef91b13ea
commit
9ae10b3aee
5 changed files with 8 additions and 7 deletions
|
@ -234,6 +234,7 @@ class Params:
|
|||
n_head_kv = config.get("num_key_value_heads", n_head),
|
||||
f_norm_eps = config["rms_norm_eps"],
|
||||
f_rope_freq_base = config.get("rope_theta"),
|
||||
rope_scaling_type = rope_scaling_type,
|
||||
f_rope_scale = f_rope_scale,
|
||||
n_orig_ctx = n_orig_ctx,
|
||||
rope_finetuned = rope_finetuned,
|
||||
|
|
|
@ -4429,8 +4429,8 @@ static __device__ void rope_yarn(
|
|||
}
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
if (freq_scale > 1.0f)
|
||||
mscale *= 1.0f + 0.1f * logf(freq_scale);
|
||||
if (freq_scale < 1.0f)
|
||||
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
*cos_theta = cosf(theta) * mscale;
|
||||
*sin_theta = sinf(theta) * mscale;
|
||||
}
|
||||
|
|
|
@ -899,8 +899,8 @@ static void rope_yarn(
|
|||
}
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
if (freq_scale > 1.0f)
|
||||
mscale *= 1.0f + 0.1f * logf(freq_scale);
|
||||
if (freq_scale < 1.0f)
|
||||
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
*cos_theta = cosf(theta) * mscale;
|
||||
*sin_theta = sinf(theta) * mscale;
|
||||
}
|
||||
|
|
4
ggml.c
4
ggml.c
|
@ -13364,8 +13364,8 @@ static void rope_yarn(
|
|||
}
|
||||
|
||||
// Get n-d magnitude scaling corrected for interpolation
|
||||
if (freq_scale > 1.0f)
|
||||
mscale *= 1.0f + 0.1f * logf(freq_scale);
|
||||
if (freq_scale < 1.0f)
|
||||
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
|
||||
*cos_theta = cosf(theta) * mscale;
|
||||
*sin_theta = sinf(theta) * mscale;
|
||||
}
|
||||
|
|
|
@ -2055,7 +2055,7 @@ static void llm_load_hparams(
|
|||
GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
|
||||
|
||||
std::string rope_scaling("linear");
|
||||
GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE));
|
||||
GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
|
||||
hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
|
||||
GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue