This commit is contained in:
KerfuffleV2 2023-11-16 01:09:41 -07:00
parent 9d39deab8f
commit 69be5c3d6d

View file

@ -1834,26 +1834,25 @@ struct llama_model_loader {
(void)result; (void)result;
throw std::runtime_error(format("request for key id %d with unhandled result type: %s", kid, typeid(T).name())); throw std::runtime_error(format("request for key id %d with unhandled result type: %s", kid, typeid(T).name()));
} }
void gk_set(const int k, uint8_t & r) { r = gguf_get_val_u8 (ctx_gguf, k); }
void gk_set(int k, uint8_t & r) { r = gguf_get_val_u8 (ctx_gguf, k); } void gk_set(const int k, uint16_t & r) { r = gguf_get_val_u16 (ctx_gguf, k); }
void gk_set(int k, uint16_t & r) { r = gguf_get_val_u16 (ctx_gguf, k); } void gk_set(const int k, uint32_t & r) { r = gguf_get_val_u32 (ctx_gguf, k); }
void gk_set(int k, uint32_t & r) { r = gguf_get_val_u32 (ctx_gguf, k); } void gk_set(const int k, uint64_t & r) { r = gguf_get_val_u64 (ctx_gguf, k); }
void gk_set(int k, uint64_t & r) { r = gguf_get_val_u64 (ctx_gguf, k); } void gk_set(const int k, int8_t & r) { r = gguf_get_val_i8 (ctx_gguf, k); }
void gk_set(int k, int8_t & r) { r = gguf_get_val_i8 (ctx_gguf, k); } void gk_set(const int k, int16_t & r) { r = gguf_get_val_i16 (ctx_gguf, k); }
void gk_set(int k, int16_t & r) { r = gguf_get_val_i16 (ctx_gguf, k); } void gk_set(const int k, int32_t & r) { r = gguf_get_val_i32 (ctx_gguf, k); }
void gk_set(int k, int32_t & r) { r = gguf_get_val_i32 (ctx_gguf, k); } void gk_set(const int k, int64_t & r) { r = gguf_get_val_i64 (ctx_gguf, k); }
void gk_set(int k, int64_t & r) { r = gguf_get_val_i64 (ctx_gguf, k); } void gk_set(const int k, float & r) { r = gguf_get_val_f32 (ctx_gguf, k); }
void gk_set(int k, float & r) { r = gguf_get_val_f32 (ctx_gguf, k); } void gk_set(const int k, double & r) { r = gguf_get_val_f64 (ctx_gguf, k); }
void gk_set(int k, double & r) { r = gguf_get_val_f64 (ctx_gguf, k); } void gk_set(const int k, bool & r) { r = gguf_get_val_bool(ctx_gguf, k); }
void gk_set(int k, bool & r) { r = gguf_get_val_bool(ctx_gguf, k); } void gk_set(const int k, std::string & r) { r = std::string(gguf_get_val_str(ctx_gguf, k)); }
void gk_set(int k, std::string & r) { r = std::string(gguf_get_val_str(ctx_gguf, k)); }
template<typename T> template<typename T>
typename std::enable_if<std::is_integral<T>::value, void>::type typename std::enable_if<std::is_integral<T>::value, void>::type
gk_set(int k, struct gk_get_arrlen<T> & r) { r.output = gguf_get_arr_n(ctx_gguf, k); } gk_set(const int k, struct gk_get_arrlen<T> r) { r.output = gguf_get_arr_n(ctx_gguf, k); }
template<typename TI, typename TO> template<typename TI, typename TO>
void gk_set_lit(TI i, TO o) { void gk_set_lit(const TI i, TO o) {
(void)i; (void)o; (void)i; (void)o;
throw std::runtime_error(format("gk_set_lit can't handle types: in=%s, out=%s", throw std::runtime_error(format("gk_set_lit can't handle types: in=%s, out=%s",
typeid(TI).name(), typeid(TO).name())); typeid(TI).name(), typeid(TO).name()));
@ -1861,11 +1860,11 @@ struct llama_model_loader {
template<typename T> template<typename T>
typename std::enable_if<std::is_integral<T>::value, void>::type typename std::enable_if<std::is_integral<T>::value, void>::type
gk_set_lit(const int64_t & i, T & o) { o = T(i); } gk_set_lit(const int64_t i, T & o) { o = T(i); }
template<typename T> template<typename T>
typename std::enable_if<std::is_floating_point<T>::value, void>::type typename std::enable_if<std::is_floating_point<T>::value, void>::type
gk_set_lit(const double & i, T & o) { o = T(i); } gk_set_lit(const double i, T & o) { o = T(i); }
template<typename T> template<typename T>
void gk_set_lit(const T & i, T & o) { o = i; } void gk_set_lit(const T & i, T & o) { o = i; }
@ -1873,46 +1872,45 @@ struct llama_model_loader {
public: public:
template<typename T> template<typename T>
bool get_key(const std::string & key, T & result, const bool required = false) { bool get_key(const std::string & key, T & result, const bool required = false) {
const auto & tt = typeid(T);
enum gguf_type gt = GGUF_TYPE_COUNT; enum gguf_type gt = GGUF_TYPE_COUNT;
enum llama_model_kv_override_type ot = LLAMA_KV_OVERRIDE_INT; enum llama_model_kv_override_type ot = LLAMA_KV_OVERRIDE_INT;
bool is_signed = false, can_override = true; bool is_signed = false, can_override = true;
if (tt == typeid(uint8_t)) { if (std::is_same<T, uint8_t>::value) {
gt = GGUF_TYPE_UINT8; gt = GGUF_TYPE_UINT8;
} else if (tt == typeid(uint16_t)) { } else if (std::is_same<T, uint16_t>::value) {
gt = GGUF_TYPE_UINT16; gt = GGUF_TYPE_UINT16;
} else if (tt == typeid(uint32_t)) { } else if (std::is_same<T, uint32_t>::value) {
gt = GGUF_TYPE_UINT32; gt = GGUF_TYPE_UINT32;
} else if (tt == typeid(uint64_t)) { } else if (std::is_same<T, uint64_t>::value) {
gt = GGUF_TYPE_UINT64; gt = GGUF_TYPE_UINT64;
} else if (tt == typeid(int8_t)) { } else if (std::is_same<T, int8_t>::value) {
is_signed = true; is_signed = true;
gt = GGUF_TYPE_INT8; gt = GGUF_TYPE_INT8;
} else if (tt == typeid(int16_t)) { } else if (std::is_same<T, int16_t>::value) {
is_signed = true; is_signed = true;
gt = GGUF_TYPE_INT16; gt = GGUF_TYPE_INT16;
} else if (tt == typeid(int32_t)) { } else if (std::is_same<T, int32_t>::value) {
is_signed = true; is_signed = true;
gt = GGUF_TYPE_INT32; gt = GGUF_TYPE_INT32;
} else if (tt == typeid(int64_t)) { } else if (std::is_same<T, int64_t>::value) {
is_signed = true; is_signed = true;
gt = GGUF_TYPE_INT64; gt = GGUF_TYPE_INT64;
} else if (tt == typeid(float)) { } else if (std::is_same<T, float>::value) {
is_signed = true; is_signed = true;
gt = GGUF_TYPE_FLOAT32; gt = GGUF_TYPE_FLOAT32;
ot = LLAMA_KV_OVERRIDE_FLOAT; ot = LLAMA_KV_OVERRIDE_FLOAT;
} else if (tt == typeid(double)) { } else if (std::is_same<T, double>::value) {
is_signed = true; is_signed = true;
gt = GGUF_TYPE_FLOAT64; gt = GGUF_TYPE_FLOAT64;
ot = LLAMA_KV_OVERRIDE_FLOAT; ot = LLAMA_KV_OVERRIDE_FLOAT;
} else if (tt == typeid(bool)) { } else if (std::is_same<T, bool>::value) {
gt = GGUF_TYPE_BOOL; gt = GGUF_TYPE_BOOL;
ot = LLAMA_KV_OVERRIDE_BOOL; ot = LLAMA_KV_OVERRIDE_BOOL;
} else if (tt == typeid(std::string)) { } else if (std::is_same<T, std::string>::value) {
can_override = false; can_override = false;
gt = GGUF_TYPE_STRING; gt = GGUF_TYPE_STRING;
} else { } else {
throw std::runtime_error(format("request for key '%s' with unknown result type: %s", key.c_str(), tt.name())); throw std::runtime_error(format("request for key '%s' with unknown result type: %s", key.c_str(), typeid(T).name()));
} }
if (can_override) { if (can_override) {
@ -1921,13 +1919,16 @@ struct llama_model_loader {
struct llama_model_kv_override & po = it->second; struct llama_model_kv_override & po = it->second;
if (po.tag != ot) { if (po.tag != ot) {
// Bad type // Bad type
// FIXME: Error reporting
} else if (ot == LLAMA_KV_OVERRIDE_INT && po.int_value < 0 && !is_signed) { } else if (ot == LLAMA_KV_OVERRIDE_INT && po.int_value < 0 && !is_signed) {
// Out of range // Out of range
// FIXME: Error reporting
} else { } else {
// FIXME: Possible informational output
switch (po.tag) { switch (po.tag) {
case LLAMA_KV_OVERRIDE_INT: gk_set_lit(po.int_value, result); break; case LLAMA_KV_OVERRIDE_INT: gk_set_lit(po.int_value, result); break;
case LLAMA_KV_OVERRIDE_FLOAT: gk_set_lit(po.float_value, result); break; case LLAMA_KV_OVERRIDE_FLOAT: gk_set_lit(po.float_value, result); break;
case LLAMA_KV_OVERRIDE_BOOL: gk_set_lit(po.bool_value, result); break; case LLAMA_KV_OVERRIDE_BOOL: gk_set_lit(po.bool_value, result); break;
default: GGML_ASSERT(false && "Impossible: Unhandled override tag type"); default: GGML_ASSERT(false && "Impossible: Unhandled override tag type");
} }
return true; return true;