From 1dd7aa9b1c1d7408b85a1d0b105d95e828496281 Mon Sep 17 00:00:00 2001 From: Jared Van Bortel Date: Fri, 2 Feb 2024 10:27:59 -0500 Subject: [PATCH] YaRN : store rope scaling type as int32_t in memory --- common/common.h | 3 +-- llama.cpp | 8 ++++---- llama.h | 2 +- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/common/common.h b/common/common.h index 24a99d728..62de25d6a 100644 --- a/common/common.h +++ b/common/common.h @@ -75,8 +75,7 @@ struct gpt_params { float yarn_beta_fast = 32.0f; // YaRN low correction dim float yarn_beta_slow = 1.0f; // YaRN high correction dim int32_t yarn_orig_ctx = 0; // YaRN original context length - int8_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // TODO: better to be int32_t for alignment - // pinging @cebtenzzre + int32_t rope_scaling_type = LLAMA_ROPE_SCALING_UNSPECIFIED; // // sampling parameters struct llama_sampling_params sparams; diff --git a/llama.cpp b/llama.cpp index 6bf7f9efb..aaee6f7d0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -747,13 +747,13 @@ struct LLM_TN { // gguf helpers // -static std::map LLAMA_ROPE_SCALING_TYPES = { +static std::map LLAMA_ROPE_SCALING_TYPES = { { LLAMA_ROPE_SCALING_NONE, "none" }, { LLAMA_ROPE_SCALING_LINEAR, "linear" }, { LLAMA_ROPE_SCALING_YARN, "yarn" }, }; -static int8_t llama_rope_scaling_type_from_string(const std::string & name) { +static int32_t llama_rope_scaling_type_from_string(const std::string & name) { for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) { if (kv.second == name) { return kv.first; @@ -1415,6 +1415,7 @@ static const size_t GiB = 1024*MiB; struct llama_hparams { bool vocab_only; + bool rope_finetuned; uint32_t n_vocab; uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; @@ -1434,8 +1435,7 @@ struct llama_hparams { float rope_freq_base_train; float rope_freq_scale_train; uint32_t n_yarn_orig_ctx; - int8_t rope_scaling_type_train : 3; - bool rope_finetuned : 1; + int32_t rope_scaling_type_train; float f_clamp_kqv; float f_max_alibi_bias; diff --git a/llama.h b/llama.h index 9a60e9bfb..cec4158bc 100644 --- a/llama.h +++ b/llama.h @@ -213,7 +213,7 @@ extern "C" { uint32_t n_batch; // prompt processing maximum batch size uint32_t n_threads; // number of threads to use for generation uint32_t n_threads_batch; // number of threads to use for batch processing - int8_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` + int32_t rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` // ref: https://github.com/ggerganov/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model