Fix YaRN inverted scaling and add "rope.scaling.type" to GGUF (#1)

This commit is contained in:
Jeffrey Quesnelle 2023-10-19 19:36:16 -07:00 committed by GitHub
parent 9ef91b13ea
commit 9ae10b3aee
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 8 additions and 7 deletions

View file

@ -234,6 +234,7 @@ class Params:
n_head_kv = config.get("num_key_value_heads", n_head),
f_norm_eps = config["rms_norm_eps"],
f_rope_freq_base = config.get("rope_theta"),
rope_scaling_type = rope_scaling_type,
f_rope_scale = f_rope_scale,
n_orig_ctx = n_orig_ctx,
rope_finetuned = rope_finetuned,

View file

@ -4429,8 +4429,8 @@ static __device__ void rope_yarn(
}
// Get n-d magnitude scaling corrected for interpolation
if (freq_scale > 1.0f)
mscale *= 1.0f + 0.1f * logf(freq_scale);
if (freq_scale < 1.0f)
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}

View file

@ -899,8 +899,8 @@ static void rope_yarn(
}
// Get n-d magnitude scaling corrected for interpolation
if (freq_scale > 1.0f)
mscale *= 1.0f + 0.1f * logf(freq_scale);
if (freq_scale < 1.0f)
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}

4
ggml.c
View file

@ -13364,8 +13364,8 @@ static void rope_yarn(
}
// Get n-d magnitude scaling corrected for interpolation
if (freq_scale > 1.0f)
mscale *= 1.0f + 0.1f * logf(freq_scale);
if (freq_scale < 1.0f)
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}

View file

@ -2055,7 +2055,7 @@ static void llm_load_hparams(
GGUF_GET_KEY(ctx, hparams.rope_freq_base_train, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
std::string rope_scaling("linear");
GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE));
GGUF_GET_KEY(ctx, rope_scaling, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling);
GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_UNSPECIFIED);