From 14cf93b14c3e5160b383136f09bbb1344c1bf0ba Mon Sep 17 00:00:00 2001
From: Jeffrey Quesnelle
Date: Fri, 20 Oct 2023 06:18:17 -0700
Subject: [PATCH] fix YaRN ramp, make mscale conditional, add --yarn-orig-ctx
 (#2)

---
 common/common.cpp |  8 ++++++++
 common/common.h   |  5 +++--
 ggml-cuda.cu      |  7 +++----
 ggml-metal.metal  |  7 +++----
 ggml.c            |  7 +++----
 llama.cpp         | 10 ++++++----
 llama.h           | 13 +++++++------
 7 files changed, 33 insertions(+), 24 deletions(-)

diff --git a/common/common.cpp b/common/common.cpp
index 3fafdfb38..d0b05c1ba 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -220,6 +220,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.rope_freq_scale = 1.0f/std::stof(argv[i]);
+        } else if (arg == "--yarn-orig-ctx") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.yarn_orig_ctx = std::stoi(argv[i]);
         } else if (arg == "--yarn-ext-factor") {
             if (++i >= argc) {
                 invalid_param = true;
@@ -737,6 +743,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     printf("  --rope-scale N        RoPE context scaling factor, expands context by a factor of N\n");
     printf("  --rope-freq-base N    RoPE base frequency, used by NTK-aware scaling (default: loaded from model)\n");
     printf("  --rope-freq-scale N   RoPE frequency scaling factor, expands context by a factor of 1/N\n");
+    printf("  --yarn-orig-ctx N     YaRN: original context size of model (default: 0 = model training context size)\n");
     printf("  --yarn-ext-factor N   YaRN: extrapolation mix factor (default: 1.0, 0.0 = full interpolation)\n");
     printf("  --yarn-attn-factor N  YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
     printf("  --yarn-beta-slow N    YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
@@ -861,6 +868,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_attn_factor = params.yarn_attn_factor;
     cparams.yarn_beta_fast   = params.yarn_beta_fast;
     cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.yarn_orig_ctx    = params.yarn_orig_ctx;
 
     return cparams;
 }
diff --git a/common/common.h b/common/common.h
index 91993dba1..01c2661b0 100644
--- a/common/common.h
+++ b/common/common.h
@@ -57,8 +57,9 @@ struct gpt_params {
     float rope_freq_scale = 0.0f; // RoPE frequency scaling factor
     float yarn_ext_factor = NAN; // YaRN extrapolation mix factor
     float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
-    float yarn_beta_fast = 32.0f; // YaRN low correction dim
-    float yarn_beta_slow = 1.0f; // YaRN high correction dim
+    float yarn_beta_fast = 32.0f;  // YaRN low correction dim
+    float yarn_beta_slow = 1.0f;   // YaRN high correction dim
+    int32_t yarn_orig_ctx = 0;     // YaRN original context length
     int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; //
 
     // sampling parameters
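A note on the common/ half of the patch: it is pure plumbing for the new flag. --yarn-orig-ctx pins the context length that YaRN's correction dims are computed against (useful when a fine-tune reports a different training context), and 0 keeps the "use the model's training context" behavior. The sketch below traces the value through the structs touched above; the main() scaffolding is illustrative only and assumes nothing beyond the names visible in this patch:

    // Illustrative scaffolding, not part of the patch.
    #include "common.h"

    int main(int argc, char ** argv) {
        gpt_params params; // yarn_orig_ctx defaults to 0 = "from model"
        if (!gpt_params_parse(argc, argv, params)) {
            return 1; // e.g. --yarn-orig-ctx given with no value
        }
        // common.cpp now copies the CLI value into the context parameters:
        llama_context_params cparams = llama_context_params_from_gpt_params(params);
        // cparams.yarn_orig_ctx is the user's override, or 0, which
        // llama_new_context_with_model() resolves to hparams.n_ctx_train.
        (void) cparams;
        return 0;
    }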
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index ff7b1e90a..4c6a36ca1 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -4406,7 +4406,7 @@ static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
 }
 
 static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / min(0.001f, high - low);
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
 }
 
@@ -4426,11 +4426,10 @@ static __device__ void rope_yarn(
     if (ext_factor != 0.0f) {
         float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
diff --git a/ggml-metal.metal b/ggml-metal.metal
index 2064884ff..427291774 100644
--- a/ggml-metal.metal
+++ b/ggml-metal.metal
@@ -880,7 +880,7 @@ kernel void kernel_alibi_f32(
 }
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / min(0.001f, high - low);
+    const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
 }
 
@@ -896,11 +896,10 @@ static void rope_yarn(
     if (ext_factor != 0.0f) {
         ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
diff --git a/ggml.c b/ggml.c
index a24341810..111ee3e56 100644
--- a/ggml.c
+++ b/ggml.c
@@ -13345,7 +13345,7 @@ static void ggml_compute_forward_clamp(
 // ggml_compute_forward_rope
 
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
-    const float y = (i0 / 2 - low) / MIN(0.001f, high - low);
+    const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
     return 1 - MIN(1, MAX(0, y));
 }
 
@@ -13361,11 +13361,10 @@ static void rope_yarn(
     if (ext_factor != 0.0f) {
         float ramp_mix = rope_yarn_ramp(corr_dims[0], corr_dims[1], i0) * ext_factor;
         theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
-    }
 
-    // Get n-d magnitude scaling corrected for interpolation
-    if (freq_scale < 1.0f)
+        // Get n-d magnitude scaling corrected for interpolation
         mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
+    }
     *cos_theta = cosf(theta) * mscale;
     *sin_theta = sinf(theta) * mscale;
 }
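A note on the ggml changes: the ramp fix and the mscale change are identical across the three backends (ggml.c, ggml-cuda.cu, ggml-metal.metal). Dividing by min(0.001f, high - low) clamped the denominator down to at most 0.001, and since high - low is normally much larger, the intended linear ramp between the two correction dims degenerated into a step at i0/2 == low; with max() the 0.001 only guards the degenerate high == low case. Moving the mscale multiply inside the ext_factor != 0.0f branch makes the 1.0f + 0.1f*logf(1.0f/freq_scale) attention-magnitude correction apply only when YaRN mixing is active, so plain interpolation (ext_factor == 0) is left untouched. Below is a self-contained host-side demo of the before/after ramp; the function bodies mirror the patch, while the correction dims in main() are made-up illustration values:

    // Compile stand-alone, e.g.: g++ -o ramp_demo ramp_demo.cpp
    #include <algorithm>
    #include <cstdio>

    // rope_yarn_ramp before this patch: min() turns the ramp into a step
    static float ramp_old(float low, float high, int i0) {
        const float y = (i0 / 2 - low) / std::min(0.001f, high - low);
        return 1.0f - std::min(1.0f, std::max(0.0f, y));
    }

    // rope_yarn_ramp after this patch: max() keeps a sane denominator
    static float ramp_new(float low, float high, int i0) {
        const float y = (i0 / 2 - low) / std::max(0.001f, high - low);
        return 1.0f - std::min(1.0f, std::max(0.0f, y));
    }

    int main() {
        const float low = 4.0f, high = 28.0f; // illustration values only
        for (int i0 = 0; i0 <= 64; i0 += 16) {
            std::printf("i0=%2d  old=%.3f  new=%.3f\n",
                        i0, ramp_old(low, high, i0), ramp_new(low, high, i0));
        }
        return 0;
    }

With the old code every dimension at or below i0/2 == low reports 1.0 and everything above it reports 0.0; the new code ramps smoothly from 1.0 down to 0.0 between low and high.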
diff --git a/llama.cpp b/llama.cpp
index cbab5f580..01e219a48 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -1113,6 +1113,7 @@ struct llama_cparams {
     float rope_freq_base;
     float rope_freq_scale;
 
+    uint32_t n_yarn_orig_ctx;
     // These hyperparameters are not exposed in GGUF, because all
     // existing YaRN models use the same values for them.
     float yarn_ext_factor;
@@ -3028,7 +3029,7 @@ static struct ggml_cgraph * llm_build_llama(
     const int32_t n_embd      = hparams.n_embd;
     const int32_t n_layer     = hparams.n_layer;
     const int32_t n_ctx       = cparams.n_ctx;
-    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
     const int32_t n_head      = hparams.n_head;
     const int32_t n_head_kv   = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -3430,7 +3431,7 @@ static struct ggml_cgraph * llm_build_baichaun(
     const int32_t n_embd      = hparams.n_embd;
     const int32_t n_layer     = hparams.n_layer;
     const int32_t n_ctx       = cparams.n_ctx;
-    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
     const int32_t n_head      = hparams.n_head;
     const int32_t n_head_kv   = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -4194,7 +4195,7 @@ static struct ggml_cgraph * llm_build_falcon(
     const int32_t n_embd      = hparams.n_embd;
     const int32_t n_layer     = hparams.n_layer;
     const int32_t n_ctx       = cparams.n_ctx;
-    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
     const int32_t n_head      = hparams.n_head;
     const int32_t n_head_kv   = hparams.n_head_kv;
     const int32_t n_embd_head = hparams.n_embd_head();
@@ -4818,7 +4819,7 @@ static struct ggml_cgraph * llm_build_persimmon(
     const int64_t n_embd      = hparams.n_embd;
     const int64_t n_layer     = hparams.n_layer;
     const int64_t n_ctx       = cparams.n_ctx;
-    const int32_t n_orig_ctx  = hparams.n_yarn_orig_ctx;
+    const int32_t n_orig_ctx  = cparams.n_yarn_orig_ctx;
     const int64_t n_head_kv   = hparams.n_head_kv;
     const int64_t n_head      = hparams.n_head;
     const int64_t n_embd_head = hparams.n_embd_head();
@@ -8676,6 +8677,7 @@ struct llama_context * llama_new_context_with_model(
     cparams.mul_mat_q = params.mul_mat_q;
 
     cparams.n_ctx            = params.n_ctx == 0 ? hparams.n_ctx_train : params.n_ctx;
+    cparams.n_yarn_orig_ctx  = params.yarn_orig_ctx == 0 ? hparams.n_ctx_train : params.yarn_orig_ctx;
     cparams.rope_freq_base   = params.rope_freq_base == 0.0f ? hparams.rope_freq_base_train : params.rope_freq_base;
     cparams.rope_freq_scale  = params.rope_freq_scale == 0.0f ? hparams.rope_freq_scale_train : params.rope_freq_scale;
 
diff --git a/llama.h b/llama.h
index 48e12cfea..5f6b14e19 100644
--- a/llama.h
+++ b/llama.h
@@ -182,12 +182,13 @@ extern "C" {
         int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
-        float rope_freq_base;  // RoPE base frequency, 0 = from model
-        float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
-        float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model
-        float yarn_attn_factor; // YaRN magnitude scaling factor
-        float yarn_beta_fast;  // YaRN low correction dim
-        float yarn_beta_slow;  // YaRN high correction dim
+        float    rope_freq_base;   // RoPE base frequency, 0 = from model
+        float    rope_freq_scale;  // RoPE frequency scaling factor, 0 = from model
+        float    yarn_ext_factor;  // YaRN extrapolation mix factor, NaN = from model
+        float    yarn_attn_factor; // YaRN magnitude scaling factor
+        float    yarn_beta_fast;   // YaRN low correction dim
+        float    yarn_beta_slow;   // YaRN high correction dim
+        uint32_t yarn_orig_ctx;    // YaRN original context size
 
         // Keep the booleans together to avoid misalignment during copy-by-value.
         bool mul_mat_q; // if true, use experimental mul_mat_q kernels
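Usage note: on the API side, yarn_orig_ctx follows the same 0-means-from-model convention as rope_freq_base and rope_freq_scale, with llama_new_context_with_model() resolving 0 to hparams.n_ctx_train. A minimal sketch against the modified llama.h; it assumes the llama_model_default_params()/llama_load_model_from_file() API of this llama.cpp vintage, and "model.gguf" is a placeholder path:

    #include "llama.h"

    int main() {
        llama_model_params mparams = llama_model_default_params();
        llama_model * model = llama_load_model_from_file("model.gguf", mparams);
        if (model == NULL) {
            return 1;
        }

        llama_context_params cparams = llama_context_default_params();
        cparams.n_ctx         = 16384; // run at an extended context
        cparams.yarn_orig_ctx = 0;     // 0 = use the model's training context

        llama_context * ctx = llama_new_context_with_model(model, cparams);
        // ... evaluate tokens as usual ...
        llama_free(ctx);
        llama_free_model(model);
        return 0;
    }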