Various cleanups

Add informational output when overrides are applied

Warn user when an override with the wrong type is specified
This commit is contained in:
KerfuffleV2 2023-11-18 02:24:52 -07:00
parent dd89015c13
commit cb5bfe0c18

View file

@ -1749,46 +1749,71 @@ namespace GGUFMeta {
return GKV::getter(ctx, k);
}
// Maps a metadata-override type tag to a short printable name for log and
// error messages. Any tag without a case below yields "unknown".
static const char * override_type_to_str(const llama_model_kv_override_type ty) {
    const char * name = "unknown";
    switch (ty) {
        case LLAMA_KV_OVERRIDE_BOOL:
            name = "bool";
            break;
        case LLAMA_KV_OVERRIDE_INT:
            name = "int";
            break;
        case LLAMA_KV_OVERRIDE_FLOAT:
            name = "float";
            break;
    }
    return name;
}
// Decides whether a user-supplied metadata override can be applied to a value
// whose expected override type is `expected_type`.
//
// Returns:
//   false - `override` is null, or its tag differs from `expected_type`
//           (the mismatch case also logs a warning naming both types).
//   true  - the tag matches; the key and the overriding value are logged first.
//
// Throws std::runtime_error when the tag matches `expected_type` but is not one
// of the bool/int/float tags handled in the switch below.
static bool validate_override(const llama_model_kv_override_type expected_type, const struct llama_model_kv_override *override) {
    if (!override) { return false; }

    if (override->tag != expected_type) {
        LLAMA_LOG_WARN("%s: Warning: Bad metadata override type for key '%s', expected %s but got %s\n",
            __func__, override->key, override_type_to_str(expected_type), override_type_to_str(override->tag));
        return false;
    }

    const char * type_name = override_type_to_str(override->tag);

    // Each message is emitted with a single LLAMA_LOG_INFO call. Printing the
    // value separately with printf would always write it to stdout, bypassing
    // whatever sink the LLAMA_LOG_* macros use (which may not be stdout) and
    // leaving the logged prefix without its value or terminating newline.
    switch (override->tag) {
        case LLAMA_KV_OVERRIDE_BOOL: {
            LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = %s\n",
                __func__, type_name, override->key, override->bool_value ? "true" : "false");
        } break;
        case LLAMA_KV_OVERRIDE_INT: {
            LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = %" PRId64 "\n",
                __func__, type_name, override->key, override->int_value);
        } break;
        case LLAMA_KV_OVERRIDE_FLOAT: {
            LLAMA_LOG_INFO("%s: Using metadata override (%5s) '%s' = %.6f\n",
                __func__, type_name, override->key, override->float_value);
        } break;
        default:
            // Shouldn't be possible to end up here, but just in case...
            throw std::runtime_error(
                format("Unsupported attempt to override %s type for metadata key %s\n",
                    type_name, override->key));
    }
    return true;
}
template<typename OT>
static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
try_override(OT & target, const struct llama_model_kv_override *override) {
if (!override) {
return false;
if (validate_override(LLAMA_KV_OVERRIDE_BOOL, override)) {
target = override->bool_value;
return true;
}
if (override->tag != LLAMA_KV_OVERRIDE_BOOL) {
return false;
}
target = override->bool_value;
return true;
}
template<typename OT>
static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
try_override(OT & target, const struct llama_model_kv_override *override) {
if (!override) {
return false;
if (validate_override(LLAMA_KV_OVERRIDE_INT, override)) {
target = override->int_value;
return true;
}
if (override->tag != LLAMA_KV_OVERRIDE_INT) {
return false;
}
if (override->int_value < 0 && !std::is_signed<T>::value) {
return false;
}
target = override->int_value;
return true;
return false;
}
template<typename OT>
static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
try_override(T & target, const struct llama_model_kv_override *override) {
if (!override) {
return false;
if (validate_override(LLAMA_KV_OVERRIDE_FLOAT, override)) {
target = override->float_value;
return true;
}
if (override->tag != LLAMA_KV_OVERRIDE_FLOAT) {
return false;
}
target = override->float_value;
return true;
return false;
}
template<typename OT>
@ -1796,7 +1821,10 @@ namespace GGUFMeta {
try_override(T & target, const struct llama_model_kv_override *override) {
(void)target;
(void)override;
return false; // cannot override str
if (!override) { return false; }
// Currently, we should never end up here so it would be a bug if we do.
throw std::runtime_error(format("Unsupported attempt to override string type for metadata key %s\n",
override ? override->key : "NULL"));
}
static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override *override = nullptr) {
@ -1807,18 +1835,16 @@ namespace GGUFMeta {
return true;
}
template <typename TT>
static bool set(const gguf_context * ctx, const char * key, TT & target, const struct llama_model_kv_override *override = nullptr) {
static bool set(const gguf_context * ctx, const char * key, T & target, const struct llama_model_kv_override *override = nullptr) {
const int kid = gguf_find_key(ctx, key);
if (kid < 0) {
return false;
}
return GKV<TT>::set(ctx, kid, target, override);
return set(ctx, kid, target, override);
}
template <typename TT>
static bool set(const gguf_context * ctx, const std::string & key, TT & target, const struct llama_model_kv_override *override = nullptr) {
return GKV<TT>::set(ctx, key.c_str(), target, override);
static bool set(const gguf_context * ctx, const std::string & key, T & target, const struct llama_model_kv_override *override = nullptr) {
return set(ctx, key.c_str(), target, override);
}
};
}
@ -1967,8 +1993,9 @@ struct llama_model_loader {
}
}
template<typename T> typename std::enable_if<std::is_integral<T>::value, bool>::type
get_arr_n(const std::string & key, T & result, const bool required = false) {
template<typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type
get_arr_n(const std::string & key, T & result, const bool required = true) {
const int kid = gguf_find_key(ctx_gguf, key.c_str());
if (kid < 0) {
@ -1986,7 +2013,8 @@ struct llama_model_loader {
return true;
}
template<typename T> typename std::enable_if<std::is_integral<T>::value, bool>::type
template<typename T>
typename std::enable_if<std::is_integral<T>::value, bool>::type
get_arr_n(const enum llm_kv kid, T & result, const bool required = true) {
return get_arr_n(llm_kv(kid), result, required);
}
@ -8228,7 +8256,7 @@ static int llama_apply_lora_from_file_internal(
std::vector<uint8_t> base_buf;
if (path_base_model) {
LLAMA_LOG_INFO("%s: loading base model from '%s'\n", __func__, path_base_model);
ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/NULL));
ml.reset(new llama_model_loader(path_base_model, /*use_mmap*/ true, /*kv_overrides*/ NULL));
size_t ctx_size;
size_t mmapped_size;