diff --git a/llama.cpp b/llama.cpp index e42e5f5ef..45d23b3f1 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1749,46 +1749,71 @@ namespace GGUFMeta { return GKV::getter(ctx, k); } + static const char * override_type_to_str(const llama_model_kv_override_type ty) { + switch (ty) { + case LLAMA_KV_OVERRIDE_BOOL: return "bool"; + case LLAMA_KV_OVERRIDE_INT: return "int"; + case LLAMA_KV_OVERRIDE_FLOAT: return "float"; + } + return "unknown"; + } + + static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override *override) { + if (!override) { return false; } + if (override->tag == expected_type) { + LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = ", + __func__, override_type_to_str(override->tag), override->key); + switch (override->tag) { + case LLAMA_KV_OVERRIDE_BOOL: { + printf("%s\n", override->bool_value ? "true" : "false"); + } break; + case LLAMA_KV_OVERRIDE_INT: { + printf("%" PRId64 "\n", override->int_value); + } break; + case LLAMA_KV_OVERRIDE_FLOAT: { + printf("%.6f\n", override->float_value); + } break; + default: + // Shouldn't be possible to end up here, but just in case... + throw std::runtime_error( + format("Unsupported attempt to override %s type for metadata key %s\n", + override_type_to_str(override->tag), override->key)); + } + return true; + } + LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n", + __func__, override->key, override_type_to_str(expected_type), override_type_to_str(override->tag)); + return false; + } + template static typename std::enable_if::value, bool>::type try_override(OT & target, const struct llama_model_kv_override *override) { - if (!override) { - return false; + if (validate_override(LLAMA_KV_OVERRIDE_BOOL, override)) { + target = override->bool_value; + return true; } - if (override->tag != LLAMA_KV_OVERRIDE_BOOL) { - return false; - } - target = override->bool_value; return true; } template static typename std::enable_if::value && std::is_integral::value, bool>::type try_override(OT & target, const struct llama_model_kv_override *override) { - if (!override) { - return false; + if (validate_override(LLAMA_KV_OVERRIDE_INT, override)) { + target = override->int_value; + return true; } - if (override->tag != LLAMA_KV_OVERRIDE_INT) { - return false; - } - if (override->int_value < 0 && !std::is_signed::value) { - return false; - } - target = override->int_value; - return true; + return false; } template static typename std::enable_if::value, bool>::type try_override(T & target, const struct llama_model_kv_override *override) { - if (!override) { - return false; + if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, override)) { + target = override->float_value; + return true; } - if (override->tag != LLAMA_KV_OVERRIDE_FLOAT) { - return false; - } - target = override->float_value; - return true; + return false; } template @@ -1796,7 +1821,10 @@ namespace GGUFMeta { try_override(T & target, const struct llama_model_kv_override *override) { (void)target; (void)override; - return false; // cannot override str + if (!override) { return false; } + // Currently, we should never end up here so it would be a bug if we do. + throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n", + override ? override->key : "NULL")); } static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override *override = nullptr) { @@ -1807,18 +1835,16 @@ namespace GGUFMeta { return true; } - template - static bool set(const gguf_context * ctx, const char * key, TT & target, const struct llama_model_kv_override *override = nullptr) { + static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) { const int kid = gguf_find_key(ctx, key); if (kid < 0) { return false; } - return GKV::set(ctx, kid, target, override); + return set(ctx, kid, target, override); } - template - static bool set(const gguf_context * ctx, const std::string & key, TT & target, const struct llama_model_kv_override *override = nullptr) { - return GKV::set(ctx, key.c_str(), target, override); + static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) { + return set(ctx, key.c_str(), target, override); } }; } @@ -1967,8 +1993,9 @@ struct llama_model_loader { } } - template typename std::enable_if::value, bool>::type - get_arr_n(const std::string & key, T & result, const bool required = false) { + template + typename std::enable_if::value, bool>::type + get_arr_n(const std::string & key, T & result, const bool required = true) { const int kid = gguf_find_key(ctx_gguf, key.c_str()); if (kid < 0) { @@ -1986,7 +2013,8 @@ struct llama_model_loader { return true; } - template typename std::enable_if::value, bool>::type + template + typename std::enable_if::value, bool>::type get_arr_n(const enum llm_kv kid, T & result, const bool required = true) { return get_arr_n(llm_kv(kid), result, required); } @@ -8228,7 +8256,7 @@ static int llama_apply_lora_from_file_internal( std::vector base_buf; if (path_base_model) { LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model); - ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/NULL)); + ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ NULL)); size_t ctx_size; size_t mmapped_size;