gguf : store scaling type as a string instead of an int
This commit is contained in:
parent
4d5fe73449
commit
746641574a
2 changed files with 24 additions and 8 deletions
|
@ -410,10 +410,10 @@ class TokenType(IntEnum):
|
|||
UNUSED = 5
|
||||
BYTE = 6
|
||||
|
||||
class RopeScalingType(IntEnum):
|
||||
NONE = 0
|
||||
LINEAR = 1
|
||||
YARN = 2
|
||||
class RopeScalingType(Enum):
|
||||
NONE = 'none'
|
||||
LINEAR = 'linear'
|
||||
YARN = 'yarn'
|
||||
|
||||
#
|
||||
# implementation
|
||||
|
@ -769,7 +769,7 @@ class GGUFWriter:
|
|||
self.add_float32(KEY_ROPE_FREQ_BASE.format(arch=self.arch), value)
|
||||
|
||||
def add_rope_scaling_type(self, value: RopeScalingType):
|
||||
self.add_uint8(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), int(value))
|
||||
self.add_string(KEY_ROPE_SCALING_TYPE.format(arch=self.arch), value.value)
|
||||
|
||||
def add_rope_scaling_factor(self, value: float):
|
||||
self.add_float32(KEY_ROPE_SCALING_FACTOR.format(arch=self.arch), value)
|
||||
|
|
22
llama.cpp
22
llama.cpp
|
@ -470,6 +470,22 @@ struct LLM_TN {
|
|||
} \
|
||||
}
|
||||
|
||||
static std::map<llama_rope_scaling_type, std::string> LLAMA_ROPE_SCALING_TYPES = {
|
||||
{ LLAMA_ROPE_SCALING_NONE, "none" },
|
||||
{ LLAMA_ROPE_SCALING_LINEAR, "linear" },
|
||||
{ LLAMA_ROPE_SCALING_YARN, "yarn" },
|
||||
};
|
||||
|
||||
static llama_rope_scaling_type llama_rope_scaling_type_from_string(const std::string & name) {
|
||||
for (const auto & kv : LLAMA_ROPE_SCALING_TYPES) {
|
||||
if (kv.second == name) {
|
||||
return kv.first;
|
||||
}
|
||||
}
|
||||
|
||||
return LLAMA_ROPE_SCALING_UNSPECIFIED;
|
||||
}
|
||||
|
||||
//
|
||||
// ggml helpers
|
||||
//
|
||||
|
@ -1711,9 +1727,9 @@ static void llm_load_hparams(llama_model_loader & ml, llama_model & model, const
|
|||
int8_t rope_scaling_type = params.rope_scaling_type;
|
||||
|
||||
if (rope_scaling_type == LLAMA_ROPE_SCALING_UNSPECIFIED) {
|
||||
uint8_t type = LLAMA_ROPE_SCALING_LINEAR;
|
||||
GGUF_GET_KEY(ctx, type, gguf_get_val_u8, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE));
|
||||
rope_scaling_type = int8_t(type);
|
||||
std::string type("linear");
|
||||
GGUF_GET_KEY(ctx, type, gguf_get_val_str, GGUF_TYPE_UINT8, false, kv(LLM_KV_ROPE_SCALING_TYPE));
|
||||
rope_scaling_type = int8_t(llama_rope_scaling_type_from_string(type));
|
||||
}
|
||||
GGML_ASSERT(rope_scaling_type >= 0 && rope_scaling_type <= LLAMA_ROPE_SCALING_MAX_VALUE);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue