diff --git a/gguf-py/gguf/gguf.py b/gguf-py/gguf/gguf.py
index 742ce2143..6bcb33701 100644
--- a/gguf-py/gguf/gguf.py
+++ b/gguf-py/gguf/gguf.py
@@ -410,10 +410,10 @@ class TokenType(IntEnum):
     UNUSED       = 5
     BYTE         = 6
 
-class RopeScalingType(IntEnum):
-    NONE   = 0
-    LINEAR = 1
-    YARN   = 2
+class RopeScalingType(Enum):
+    NONE   = 'none'
+    LINEAR = 'linear'
+    YARN   = 'yarn'
 
 #
 # implementation
@@ -769,7 +769,7 @@ class GGUFWriter:
         self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
 
     def add_rope_scaling_type(self, value: RopeScalingType):
-        self.add_uint8(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), int(value))
+        self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
 
     def add_rope_scaling_factor(self, value: float):
         self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
diff --git a/llama.cpp b/llama.cpp
index d862541d1..4f68ba692 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -470,6 +470,22 @@ struct LLM_TN {
     } \
 }
 
+static std::map<llama_rope_scaling_type, std::string> LLAMA_ROPE_SCALING_TYPES = {
+    { LLAMA_ROPE_SCALING_NONE,   "none"   },
+    { LLAMA_ROPE_SCALING_LINEAR, "linear" },
+    { LLAMA_ROPE_SCALING_YARN,   "yarn"   },
+};
+
+static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
+    for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
+        if (kv.second == name) {
+            return kv.first;
+        }
+    }
+
+    return LLAMA_ROPE_SCALING_UNSPECIFIED;
+}
+
 //
 // ggml helpers
 //
@@ -1711,9 +1727,9 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model,
         int8_t rope_scaling_type = params.rope_scaling_type;
 
         if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
-            uint8_t type = LLAMA_ROPE_SCALING_LINEAR;
-            GGUF_GET_KEY(ctx, type, gguf_get_val_u8, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE));
-            rope_scaling_type = int8_t(type);
+            std::string type("linear");
+            GGUF_GET_KEY(ctx, type, gguf_get_val_str, GGUF_TYPE_STRING, false, kv(LLM_KV_ROPE_SCALING_TYPE));
+            rope_scaling_type = int8_t(llama_rope_scaling_type_from_string(type));
        }
 
         GGML_ASSERT(rope_scaling_type >= 0 && rope_scaling_type <= LLAMA_ROPE_SCALING_MAX_VALUE);
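
Net effect: the `rope.scaling.type` KV is now written as a GGUF string (`"none"`, `"linear"`, `"yarn"`) instead of a `uint8`, and the C++ loader matches on the name via `llama_rope_scaling_type_from_string`, falling back to `LLAMA_ROPE_SCALING_UNSPECIFIED` for unrecognized values. Below is a minimal Python sketch of the resulting round trip, assuming gguf-py's `GGUFWriter(path, arch)` constructor; the `parse_rope_scaling_type` helper is hypothetical and mirrors the C++ lookup:

```python
from typing import Optional

from gguf.gguf import GGUFWriter, RopeScalingType

# Writer side: the enum's value is the canonical string stored in the file.
writer = GGUFWriter("model.gguf", "llama")
writer.add_rope_scaling_type(RopeScalingType.YARN)  # queues the string "yarn"

# Reader side (hypothetical helper): Enum lookup-by-value resolves the
# stored string; unknown names map to None, the Python analogue of
# LLAMA_ROPE_SCALING_UNSPECIFIED, rather than to a silently wrong ordinal.
def parse_rope_scaling_type(name: str) -> Optional[RopeScalingType]:
    try:
        return RopeScalingType(name)   # 'yarn' -> RopeScalingType.YARN
    except ValueError:
        return None

assert parse_rope_scaling_type("yarn") is RopeScalingType.YARN
assert parse_rope_scaling_type("something-new") is None
```

Storing the name rather than the ordinal keeps the file self-describing: a future scaling type can be added without freezing the integer numbering into every existing GGUF file.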