gguf : store scaling type as a string instead of an int

This commit is contained in:
Cebtenzzre 2023-10-07 12:57:55 -04:00
parent 4d5fe73449
commit 746641574a
2 changed files with 24 additions and 8 deletions

View file

@ -410,10 +410,10 @@ class TokenType(IntEnum):
UNUSED = 5 UNUSED = 5
BYTE = 6 BYTE = 6
class RopeScalingType(IntEnum): class RopeScalingType(Enum):
NONE = 0 NONE = 'none'
LINEAR = 1 LINEAR = 'linear'
YARN = 2 YARN = 'yarn'
# #
# implementation # implementation
@ -769,7 +769,7 @@ class GGUFWriter:
self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value) self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
def add_rope_scaling_type(self, value: RopeScalingType): def add_rope_scaling_type(self, value: RopeScalingType):
self.add_uint8(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), int(value)) self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
def add_rope_scaling_factor(self, value: float): def add_rope_scaling_factor(self, value: float):
self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value) self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)

View file

@ -470,6 +470,22 @@ struct LLM_TN {
} \ } \
} }
// Table pairing each RoPE scaling type enum value with its string name
// ("none", "linear", "yarn"). llama_rope_scaling_type_from_string() scans the
// values of this table to translate a name back into the enum.
// LLAMA_ROPE_SCALING_UNSPECIFIED intentionally has no entry here.
static std::map<llama_rope_scaling_type, std::string> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_NONE, "none" },
{ LLAMA_ROPE_SCALING_LINEAR, "linear" },
{ LLAMA_ROPE_SCALING_YARN, "yarn" },
};
// Translate a RoPE scaling type name into its enum value.
//
// name: the string to look up; compared exactly (case-sensitive) against the
//       names stored in LLAMA_ROPE_SCALING_TYPES.
//
// Returns the enum key of the first table entry (in the map's iteration order)
// whose name equals `name`, or LLAMA_ROPE_SCALING_UNSPECIFIED when no entry
// matches.
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
    // The table is keyed by enum, so a reverse lookup has to walk the entries.
    auto it = LLAMA_ROPE_SCALING_TYPES.begin();
    while (it != LLAMA_ROPE_SCALING_TYPES.end()) {
        if (it->second == name) {
            return it->first;
        }
        ++it;
    }

    // Unknown name: report "unspecified" and let the caller decide what to do.
    return LLAMA_ROPE_SCALING_UNSPECIFIED;
}
// //
// ggml helpers // ggml helpers
// //
@ -1711,9 +1727,9 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
int8_t rope_scaling_type = params.rope_scaling_type; int8_t rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) { if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
uint8_t type = LLAMA_ROPE_SCALING_LINEAR; std::string type("linear");
GGUF_GET_KEY(ctx, type, gguf_get_val_u8, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE)); GGUF_GET_KEY(ctx, type, gguf_get_val_str, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE));
rope_scaling_type = int8_t(type); rope_scaling_type = int8_t(llama_rope_scaling_type_from_string(type));
} }
GGML_ASSERT(rope_scaling_type >= 0 && rope_scaling_type <= LLAMA_ROPE_SCALING_MAX_VALUE); GGML_ASSERT(rope_scaling_type >= 0 && rope_scaling_type <= LLAMA_ROPE_SCALING_MAX_VALUE);