diff --git a/llama.cpp b/llama.cpp
index a29f5b419..742b8b86b 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -75,6 +75,7 @@
 #include <set>
 #include <sstream>
 #include <thread>
+#include <type_traits>
 #include <unordered_map>
 
 #if defined(_MSC_VER)
@@ -1682,6 +1683,161 @@ static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
     return buf;
 }
 
+namespace GGUFMeta {
+    template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
+    struct GKV_Base_Type {
+        static constexpr gguf_type gt = gt_;
+
+        static T getter(const gguf_context * ctx, const int kid) {
+            return gfun(ctx, kid);
+        }
+    };
+
+    template<typename T> struct GKV_Base;
+
+    template<> struct GKV_Base<bool        >: GKV_Base_Type<bool,         GGUF_TYPE_BOOL,    gguf_get_val_bool> {};
+    template<> struct GKV_Base<uint8_t     >: GKV_Base_Type<uint8_t,      GGUF_TYPE_UINT8,   gguf_get_val_u8  > {};
+    template<> struct GKV_Base<uint16_t    >: GKV_Base_Type<uint16_t,     GGUF_TYPE_UINT16,  gguf_get_val_u16 > {};
+    template<> struct GKV_Base<uint32_t    >: GKV_Base_Type<uint32_t,     GGUF_TYPE_UINT32,  gguf_get_val_u32 > {};
+    template<> struct GKV_Base<uint64_t    >: GKV_Base_Type<uint64_t,     GGUF_TYPE_UINT64,  gguf_get_val_u64 > {};
+    template<> struct GKV_Base<int8_t      >: GKV_Base_Type<int8_t,       GGUF_TYPE_INT8,    gguf_get_val_i8  > {};
+    template<> struct GKV_Base<int16_t     >: GKV_Base_Type<int16_t,      GGUF_TYPE_INT16,   gguf_get_val_i16 > {};
+    template<> struct GKV_Base<int32_t     >: GKV_Base_Type<int32_t,      GGUF_TYPE_INT32,   gguf_get_val_i32 > {};
+    template<> struct GKV_Base<int64_t     >: GKV_Base_Type<int64_t,      GGUF_TYPE_INT64,   gguf_get_val_i64 > {};
+    template<> struct GKV_Base<float       >: GKV_Base_Type<float,        GGUF_TYPE_FLOAT32, gguf_get_val_f32 > {};
+    template<> struct GKV_Base<double      >: GKV_Base_Type<double,       GGUF_TYPE_FLOAT64, gguf_get_val_f64 > {};
+    template<> struct GKV_Base<const char *>: GKV_Base_Type<const char *, GGUF_TYPE_STRING,  gguf_get_val_str > {};
+
+    struct GetArrayLen{int value;};
+    template<> struct GKV_Base<GetArrayLen> {
+        public:
+        static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
+        static GetArrayLen getter(const gguf_context *ctx, const int k) {
+            return GetArrayLen{gguf_get_arr_n(ctx, k)};
+        }
+    };
+
+    struct ArrayInfo{
+        const gguf_type gt;
+        const size_t length;
+        const void * data;
+    };
+
+    template<> struct GKV_Base<ArrayInfo> {
+        public:
+        static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
+        static ArrayInfo getter(const gguf_context *ctx, const int k) {
+            return ArrayInfo {
+                gguf_get_arr_type(ctx, k),
+                size_t(gguf_get_arr_n(ctx, k)),
+                gguf_get_arr_data(ctx, k),
+            };
+        }
+    };
+
+    template<typename T>
+    class GKV: public GKV_Base<T> {
+        GKV() = delete;
+
+        public:
+        static T get_kv(const gguf_context * ctx, const int k) {
+            const enum gguf_type kt = gguf_get_kv_type(ctx, k);
+
+            if (kt != GKV::gt) {
+                throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
+                    gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
+            }
+            return GKV::getter(ctx, k);
+        }
+
+        // This can't be uncommented.
+        // template<typename OT> static bool try_override(OT & target, const struct llama_model_kv_override *override) = delete;
+
+        template<typename OT>
+        static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
+        try_override(OT & target, const struct llama_model_kv_override *override) {
+            if (!override) {
+                return false;
+            }
+            if (override->tag != LLAMA_KV_OVERRIDE_BOOL) {
+                return false;
+            }
+            target = override->bool_value;
+            return true;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
+        try_override(OT & target, const struct llama_model_kv_override *override) {
+            if (!override) {
+                return false;
+            }
+            if (override->tag != LLAMA_KV_OVERRIDE_INT) {
+                return false;
+            }
+            if (override->int_value < 0 && !std::is_signed<OT>::value) {
+                return false;
+            }
+            target = override->int_value;
+            return true;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
+        try_override(T & target, const struct llama_model_kv_override *override) {
+            if (!override) {
+                return false;
+            }
+            if (override->tag != LLAMA_KV_OVERRIDE_FLOAT) {
+                return false;
+            }
+            target = override->float_value;
+            return true;
+        }
+
+        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override *override = nullptr) {
+            if (try_override<T>(target, override)) {
+                return true;
+            }
+            target = get_kv(ctx, k);
+            return true;
+        }
+
+        template<typename TT>
+        static bool set(const gguf_context * ctx, const char * key, TT & target, const struct llama_model_kv_override *override = nullptr) {
+            const int kid = gguf_find_key(ctx, key);
+            if (kid < 0) {
+                return false;
+            }
+            return GKV<TT>::set(ctx, kid, target, override);
+        }
+
+        template<typename TT>
+        static bool set(const gguf_context * ctx, const std::string & key, TT & target, const struct llama_model_kv_override *override = nullptr) {
+            return GKV<TT>::set(ctx, key.c_str(), target, override);
+        }
+    };
+
+    template<>
+    class GKV<std::string>: public GKV_Base<const char *> {
+        using BT = const char *;
+        public:
+        static bool set(const gguf_context * ctx, const int k, std::string & target, const struct llama_model_kv_override *override = nullptr) {
+            (void)override;
+            target = std::string(GKV<BT>::get_kv(ctx, k));
+            return true;
+        }
+
+        static bool set(const gguf_context * ctx, const char * key, std::string & target, const struct llama_model_kv_override *override = nullptr) {
+            return GKV<BT>::set(ctx, key, target, override);
+        }
+
+        static bool set(const gguf_context * ctx, const std::string & key, std::string & target, const struct llama_model_kv_override *override = nullptr) {
+            return GKV<BT>::set(ctx, key, target, override);
+        }
+    };
+}
+
 struct llama_model_loader {
     int n_kv      = 0;
     int n_tensors = 0;
@@ -1826,117 +1982,10 @@ struct llama_model_loader {
         }
     }
 
-    private:
-    template<typename T> struct gk_get_arrlen { T & output; };
-    template<typename TI, typename TO> struct gk_set_literal { TI & input; TO & output; };
-    template<typename T>
-    void gk_set(int kid, T & result) {
-        (void)result;
-        throw std::runtime_error(format("request for key id %d with unhandled result type: %s", kid, typeid(T).name()));
-    }
-    void gk_set(const int k, uint8_t & r)     { r = gguf_get_val_u8  (ctx_gguf, k); }
-    void gk_set(const int k, uint16_t & r)    { r = gguf_get_val_u16 (ctx_gguf, k); }
-    void gk_set(const int k, uint32_t & r)    { r = gguf_get_val_u32 (ctx_gguf, k); }
-    void gk_set(const int k, uint64_t & r)    { r = gguf_get_val_u64 (ctx_gguf, k); }
-    void gk_set(const int k, int8_t & r)      { r = gguf_get_val_i8  (ctx_gguf, k); }
-    void gk_set(const int k, int16_t & r)     { r = gguf_get_val_i16 (ctx_gguf, k); }
-    void gk_set(const int k, int32_t & r)     { r = gguf_get_val_i32 (ctx_gguf, k); }
-    void gk_set(const int k, int64_t & r)     { r = gguf_get_val_i64 (ctx_gguf, k); }
-    void gk_set(const int k, float & r)       { r = gguf_get_val_f32 (ctx_gguf, k); }
-    void gk_set(const int k, double & r)      { r = gguf_get_val_f64 (ctx_gguf, k); }
-    void gk_set(const int k, bool & r)        { r = gguf_get_val_bool(ctx_gguf, k); }
-    void gk_set(const int k, std::string & r) { r = std::string(gguf_get_val_str(ctx_gguf, k)); }
-
-    template<typename T>
-    typename std::enable_if<std::is_integral<T>::value, void>::type
-    gk_set(const int k, struct gk_get_arrlen<T> r) { r.output = gguf_get_arr_n(ctx_gguf, k); }
-
-    template<typename TI, typename TO>
-    void gk_set_lit(const TI i, TO o) {
-        (void)i; (void)o;
-        throw std::runtime_error(format("gk_set_lit can't handle types: in=%s, out=%s",
-            typeid(TI).name(), typeid(TO).name()));
-    }
-
-    template<typename T>
-    typename std::enable_if<std::is_integral<T>::value, void>::type
-    gk_set_lit(const int64_t i, T & o) { o = T(i); }
-
-    template<typename T>
-    typename std::enable_if<std::is_floating_point<T>::value, void>::type
-    gk_set_lit(const double i, T & o) { o = T(i); }
-
-    template<typename T>
-    void gk_set_lit(const T & i, T & o) { o = i; }
-
-    public:
-    template<typename T>
-    bool get_key(const std::string & key, T & result, const bool required = false) {
-        enum gguf_type gt = GGUF_TYPE_COUNT;
-        enum llama_model_kv_override_type ot = LLAMA_KV_OVERRIDE_INT;
-        bool is_signed = false, can_override = true;
-        if (std::is_same<T, uint8_t>::value) {
-            gt = GGUF_TYPE_UINT8;
-        } else if (std::is_same<T, uint16_t>::value) {
-            gt = GGUF_TYPE_UINT16;
-        } else if (std::is_same<T, uint32_t>::value) {
-            gt = GGUF_TYPE_UINT32;
-        } else if (std::is_same<T, uint64_t>::value) {
-            gt = GGUF_TYPE_UINT64;
-        } else if (std::is_same<T, int8_t>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_INT8;
-        } else if (std::is_same<T, int16_t>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_INT16;
-        } else if (std::is_same<T, int32_t>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_INT32;
-        } else if (std::is_same<T, int64_t>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_INT64;
-        } else if (std::is_same<T, float>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_FLOAT32;
-            ot = LLAMA_KV_OVERRIDE_FLOAT;
-        } else if (std::is_same<T, double>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_FLOAT64;
-            ot = LLAMA_KV_OVERRIDE_FLOAT;
-        } else if (std::is_same<T, bool>::value) {
-            gt = GGUF_TYPE_BOOL;
-            ot = LLAMA_KV_OVERRIDE_BOOL;
-        } else if (std::is_same<T, std::string>::value) {
-            can_override = false;
-            gt = GGUF_TYPE_STRING;
-        } else {
-            throw std::runtime_error(format("request for key '%s' with unknown result type: %s", key.c_str(), typeid(T).name()));
-        }
-
-        if (can_override) {
-            auto it = kv_overrides.find(key);
-            if (it != kv_overrides.end()) {
-                struct llama_model_kv_override & po = it->second;
-                if (po.tag != ot) {
-                    // Bad type
-                    // FIXME: Error reporting
-                } else if (ot == LLAMA_KV_OVERRIDE_INT && po.int_value < 0 && !is_signed) {
-                    // Out of range
-                    // FIXME: Error reporting
-                } else {
-                    // FIXME: Possible informational output
-                    switch (po.tag) {
-                        case LLAMA_KV_OVERRIDE_INT:   gk_set_lit(po.int_value, result); break;
-                        case LLAMA_KV_OVERRIDE_FLOAT: gk_set_lit(po.float_value, result); break;
-                        case LLAMA_KV_OVERRIDE_BOOL:  gk_set_lit(po.bool_value, result); break;
-                        default: GGML_ASSERT(false && "Impossible: Unhandled override tag type");
-                    }
-                    return true;
-                }
-            }
-        }
-
+    template<typename T> typename std::enable_if<std::is_integral<T>::value, bool>::type
+    get_arr_n(const std::string & key, T & result, const bool required = false) {
         const int kid = gguf_find_key(ctx_gguf, key.c_str());
+
         if (kid < 0) {
             if (required) {
                 throw std::runtime_error(format("key not found in model: %s", key.c_str()));
@@ -1944,23 +1993,37 @@ struct llama_model_loader {
             return false;
         }
 
-        const enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
-        if (ktype == GGUF_TYPE_ARRAY && ot == LLAMA_KV_OVERRIDE_INT) {
-            gk_get_arrlen<T> arrlen = {result};
-            gk_set(kid, arrlen);
-            return true;
-        }
-        if (ktype != gt) {
-            throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
-                key.c_str(), gguf_type_name(ktype), gguf_type_name(gt)));
-        }
-        gk_set(kid, result);
+        struct GGUFMeta::ArrayInfo arr_info =
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx_gguf, kid);
+
+        result = arr_info.length;
         return true;
     }
 
+    template<typename T> typename std::enable_if<std::is_integral<T>::value, bool>::type
+    get_arr_n(const enum llm_kv kid, T & result, const bool required = true) {
+        return get_arr_n(llm_kv(kid), result, required);
+    }
+
     template<typename T>
-    bool get_key(const enum llm_kv kid, T & result, const bool required = false) {
+    bool get_key(const std::string & key, T & result, const bool required = true) {
+        auto it = kv_overrides.find(key);
+
+        const struct llama_model_kv_override * override =
+            it != kv_overrides.end() ? &it->second : nullptr;
+
+        const bool found = GGUFMeta::GKV<T>::set(ctx_gguf, key, result, override);
+
+        if (required && !found) {
+            throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+        }
+
+        return found;
+    }
+
+    template<typename T>
+    bool get_key(const enum llm_kv kid, T & result, const bool required = true) {
         return get_key(llm_kv(kid), result, required);
     }
 
@@ -2222,12 +2285,12 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
 
     // get hparams kv
-    ml.get_key(LLM_KV_TOKENIZER_LIST, hparams.n_vocab, true);
-    ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train, true);
-    ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd, true);
-    ml.get_key(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff, true);
-    ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head, true);
-    ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer, true);
+    ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
+    ml.get_key  (LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train);
+    ml.get_key  (LLM_KV_EMBEDDING_LENGTH, hparams.n_embd);
+    ml.get_key  (LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff);
+    ml.get_key  (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
+    ml.get_key  (LLM_KV_BLOCK_COUNT, hparams.n_layer);
 
     // n_head_kv is optional, default to n_head
     hparams.n_head_kv = hparams.n_head;
@@ -2251,8 +2314,8 @@
 
     // rope_freq_scale (inverse of the kv) is optional
     float ropescale = 0.0f;
-    ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false);
-    if (ropescale == 0.0f) { // try the old key name
+    if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) {
+        // try the old key name
         ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false);
     }
     hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
@@ -2276,7 +2339,7 @@
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
 
                 switch (hparams.n_layer) {
                     case 26: model.type = e_model::MODEL_3B; break;
@@ -2290,7 +2353,7 @@
             } break;
         case LLM_ARCH_FALCON:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_7B; break;
@@ -2300,7 +2363,7 @@
             } break;
        case LLM_ARCH_BAICHUAN:
            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_7B; break;
                     case 40: model.type = e_model::MODEL_13B; break;
@@ -2309,7 +2372,7 @@
            } break;
        case LLM_ARCH_STARCODER:
            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_1B; break;
                     case 36: model.type = e_model::MODEL_3B; break;
@@ -2320,7 +2383,7 @@
            } break;
        case LLM_ARCH_PERSIMMON:
            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 switch (hparams.n_layer) {
                     case 36: model.type = e_model::MODEL_8B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
@@ -2328,7 +2391,7 @@
            } break;
        case LLM_ARCH_REFACT:
            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_1B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
@@ -2336,7 +2399,7 @@
            } break;
        case LLM_ARCH_BLOOM:
            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_1B; break;
@@ -2351,9 +2414,9 @@
            {
                 hparams.f_clamp_kqv = 0.0f;
 
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
-                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false);
-                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
+                ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,      hparams.f_clamp_kqv, false);
+                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
 
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_7B; break;
@@ -2363,7 +2426,7 @@
            } break;
        case LLM_ARCH_STABLELM:
            {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_3B; break;
@@ -2411,7 +2474,7 @@ static void llm_load_vocab(
     {
         std::string tokenizer_name;
 
-        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name, true);
+        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name);
 
         if (tokenizer_name == "llama") {
             vocab.type = LLAMA_VOCAB_TYPE_SPM;