gguf : store scaling type as a string instead of an int

This commit is contained in:
Cebtenzzre 2023-10-07 12:57:55 -04:00
parent 4d5fe73449
commit 746641574a
2 changed files with 24 additions and 8 deletions

View file

@ -410,10 +410,10 @@ class TokenType(IntEnum):
UNUSED = 5
BYTE = 6
class RopeScalingType(IntEnum):
NONE = 0
LINEAR = 1
YARN = 2
class RopeScalingType(Enum):
NONE = 'none'
LINEAR = 'linear'
YARN = 'yarn'
#
# implementation
@ -769,7 +769,7 @@ class GGUFWriter:
self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
def add_rope_scaling_type(self, value: RopeScalingType):
self.add_uint8(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), int(value))
self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
def add_rope_scaling_factor(self, value: float):
self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)

View file

@ -470,6 +470,22 @@ struct LLM_TN {
} \
}
static std::map<llama_rope_scaling_type, std::string> LLAMA_ROPE_SCALING_TYPES = {
{ LLAMA_ROPE_SCALING_NONE, "none" },
{ LLAMA_ROPE_SCALING_LINEAR, "linear" },
{ LLAMA_ROPE_SCALING_YARN, "yarn" },
};
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
if (kv.second == name) {
return kv.first;
}
}
return LLAMA_ROPE_SCALING_UNSPECIFIED;
}
//
// ggml helpers
//
@ -1711,9 +1727,9 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
int8_t rope_scaling_type = params.rope_scaling_type;
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
uint8_t type = LLAMA_ROPE_SCALING_LINEAR;
GGUF_GET_KEY(ctx, type, gguf_get_val_u8, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE));
rope_scaling_type = int8_t(type);
std::string type("linear");
GGUF_GET_KEY(ctx, type, gguf_get_val_str, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE));
rope_scaling_type = int8_t(llama_rope_scaling_type_from_string(type));
}
GGML_ASSERT(rope_scaling_type >= 0 && rope_scaling_type <= LLAMA_ROPE_SCALING_MAX_VALUE);