This commit is contained in:
KerfuffleV2 2023-11-16 01:09:41 -07:00
parent 9d39deab8f
commit 69be5c3d6d

View file

@@ -1834,26 +1834,25 @@ struct llama_model_loader {
(void)result;
throw std::runtime_error(format("request for key id %d with unhandled result type: %s", kid, typeid(T).name()));
}
// Typed scalar readers, one overload per supported result type. Each stores
// into `r` whatever the matching gguf_get_val_* accessor returns for key id
// `k` in ctx_gguf; the std::string overload wraps the accessor's result in a
// std::string.
// NOTE(review): the same overload set appears again directly below with
// `const int k` — presumably these are the pre-change lines of the diff.
void gk_set(int k, uint8_t & r) { r = gguf_get_val_u8 (ctx_gguf, k); }
void gk_set(int k, uint16_t & r) { r = gguf_get_val_u16 (ctx_gguf, k); }
void gk_set(int k, uint32_t & r) { r = gguf_get_val_u32 (ctx_gguf, k); }
void gk_set(int k, uint64_t & r) { r = gguf_get_val_u64 (ctx_gguf, k); }
void gk_set(int k, int8_t & r) { r = gguf_get_val_i8 (ctx_gguf, k); }
void gk_set(int k, int16_t & r) { r = gguf_get_val_i16 (ctx_gguf, k); }
void gk_set(int k, int32_t & r) { r = gguf_get_val_i32 (ctx_gguf, k); }
void gk_set(int k, int64_t & r) { r = gguf_get_val_i64 (ctx_gguf, k); }
void gk_set(int k, float & r) { r = gguf_get_val_f32 (ctx_gguf, k); }
void gk_set(int k, double & r) { r = gguf_get_val_f64 (ctx_gguf, k); }
void gk_set(int k, bool & r) { r = gguf_get_val_bool(ctx_gguf, k); }
void gk_set(int k, std::string & r) { r = std::string(gguf_get_val_str(ctx_gguf, k)); }
// Typed scalar readers, one overload per supported result type. Each stores
// into `r` whatever the matching gguf_get_val_* accessor returns for key id
// `k` in ctx_gguf; the std::string overload wraps the accessor's result in a
// std::string.
// The top-level `const` on the by-value `k` only stops the body from modifying
// its own copy; it does not change the function's signature for callers.
// NOTE(review): the same overload set appears directly above with plain
// `int k` — presumably these are the post-change lines of the diff.
void gk_set(const int k, uint8_t & r) { r = gguf_get_val_u8 (ctx_gguf, k); }
void gk_set(const int k, uint16_t & r) { r = gguf_get_val_u16 (ctx_gguf, k); }
void gk_set(const int k, uint32_t & r) { r = gguf_get_val_u32 (ctx_gguf, k); }
void gk_set(const int k, uint64_t & r) { r = gguf_get_val_u64 (ctx_gguf, k); }
void gk_set(const int k, int8_t & r) { r = gguf_get_val_i8 (ctx_gguf, k); }
void gk_set(const int k, int16_t & r) { r = gguf_get_val_i16 (ctx_gguf, k); }
void gk_set(const int k, int32_t & r) { r = gguf_get_val_i32 (ctx_gguf, k); }
void gk_set(const int k, int64_t & r) { r = gguf_get_val_i64 (ctx_gguf, k); }
void gk_set(const int k, float & r) { r = gguf_get_val_f32 (ctx_gguf, k); }
void gk_set(const int k, double & r) { r = gguf_get_val_f64 (ctx_gguf, k); }
void gk_set(const int k, bool & r) { r = gguf_get_val_bool(ctx_gguf, k); }
void gk_set(const int k, std::string & r) { r = std::string(gguf_get_val_str(ctx_gguf, k)); }
// Array-length reader, enabled only when T is an integral type: stores
// gguf_get_arr_n(ctx_gguf, k) into r.output.
// NOTE(review): two declarators follow a single template header here —
// presumably the before/after lines of the diff. The first takes `r` by
// non-const reference; the second takes it by value, so its write reaches the
// caller only if gk_get_arrlen<T>::output is a reference member — confirm
// against the struct definition, which is not shown here.
template<typename T>
typename std::enable_if<std::is_integral<T>::value, void>::type
gk_set(int k, struct gk_get_arrlen<T> & r) { r.output = gguf_get_arr_n(ctx_gguf, k); }
gk_set(const int k, struct gk_get_arrlen<T> r) { r.output = gguf_get_arr_n(ctx_gguf, k); }
template<typename TI, typename TO>
void gk_set_lit(TI i, TO o) {
void gk_set_lit(const TI i, TO o) {
(void)i; (void)o;
throw std::runtime_error(format("gk_set_lit can't handle types: in=%s, out=%s",
typeid(TI).name(), typeid(TO).name()));
@@ -1861,11 +1860,11 @@ struct llama_model_loader {
// Literal conversion for integral targets (enabled only when T is integral):
// converts the int64_t input to T with a function-style cast and stores it in
// `o`. No range check is performed here, so out-of-range values are narrowed
// by the cast.
// NOTE(review): two declarators follow a single template header —
// presumably the before/after lines of the diff (input taken by const
// reference vs. by value).
template<typename T>
typename std::enable_if<std::is_integral<T>::value, void>::type
gk_set_lit(const int64_t & i, T & o) { o = T(i); }
gk_set_lit(const int64_t i, T & o) { o = T(i); }
// Literal conversion for floating-point targets (enabled only when T is a
// floating-point type): converts the double input to T with a function-style
// cast and stores it in `o`.
// NOTE(review): two declarators follow a single template header —
// presumably the before/after lines of the diff (input taken by const
// reference vs. by value).
template<typename T>
typename std::enable_if<std::is_floating_point<T>::value, void>::type
gk_set_lit(const double & i, T & o) { o = T(i); }
gk_set_lit(const double i, T & o) { o = T(i); }
// Identity case: input and output share the same type T, so the value is
// copy-assigned across unchanged.
template<typename T>
void gk_set_lit(const T & i, T & o) {
    o = i;
}
@@ -1873,46 +1872,45 @@ struct llama_model_loader {
public:
template<typename T>
bool get_key(const std::string & key, T & result, const bool required = false) {
const auto & tt = typeid(T);
enum gguf_type gt = GGUF_TYPE_COUNT;
enum llama_model_kv_override_type ot = LLAMA_KV_OVERRIDE_INT;
bool is_signed = false, can_override = true;
if (tt == typeid(uint8_t)) {
if (std::is_same<T, uint8_t>::value) {
gt = GGUF_TYPE_UINT8;
} else if (tt == typeid(uint16_t)) {
} else if (std::is_same<T, uint16_t>::value) {
gt = GGUF_TYPE_UINT16;
} else if (tt == typeid(uint32_t)) {
} else if (std::is_same<T, uint32_t>::value) {
gt = GGUF_TYPE_UINT32;
} else if (tt == typeid(uint64_t)) {
} else if (std::is_same<T, uint64_t>::value) {
gt = GGUF_TYPE_UINT64;
} else if (tt == typeid(int8_t)) {
} else if (std::is_same<T, int8_t>::value) {
is_signed = true;
gt = GGUF_TYPE_INT8;
} else if (tt == typeid(int16_t)) {
} else if (std::is_same<T, int16_t>::value) {
is_signed = true;
gt = GGUF_TYPE_INT16;
} else if (tt == typeid(int32_t)) {
} else if (std::is_same<T, int32_t>::value) {
is_signed = true;
gt = GGUF_TYPE_INT32;
} else if (tt == typeid(int64_t)) {
} else if (std::is_same<T, int64_t>::value) {
is_signed = true;
gt = GGUF_TYPE_INT64;
} else if (tt == typeid(float)) {
} else if (std::is_same<T, float>::value) {
is_signed = true;
gt = GGUF_TYPE_FLOAT32;
ot = LLAMA_KV_OVERRIDE_FLOAT;
} else if (tt == typeid(double)) {
} else if (std::is_same<T, double>::value) {
is_signed = true;
gt = GGUF_TYPE_FLOAT64;
ot = LLAMA_KV_OVERRIDE_FLOAT;
} else if (tt == typeid(bool)) {
} else if (std::is_same<T, bool>::value) {
gt = GGUF_TYPE_BOOL;
ot = LLAMA_KV_OVERRIDE_BOOL;
} else if (tt == typeid(std::string)) {
} else if (std::is_same<T, std::string>::value) {
can_override = false;
gt = GGUF_TYPE_STRING;
} else {
throw std::runtime_error(format("request for key '%s' with unknown result type: %s", key.c_str(), tt.name()));
throw std::runtime_error(format("request for key '%s' with unknown result type: %s", key.c_str(), typeid(T).name()));
}
if (can_override) {
@@ -1921,13 +1919,16 @@ struct llama_model_loader {
struct llama_model_kv_override & po = it->second;
if (po.tag != ot) {
// Bad type
// FIXME: Error reporting
} else if (ot == LLAMA_KV_OVERRIDE_INT && po.int_value < 0 && !is_signed) {
// Out of range
// FIXME: Error reporting
} else {
// FIXME: Possible informational output
switch (po.tag) {
case LLAMA_KV_OVERRIDE_INT: gk_set_lit(po.int_value, result); break;
case LLAMA_KV_OVERRIDE_INT: gk_set_lit(po.int_value, result); break;
case LLAMA_KV_OVERRIDE_FLOAT: gk_set_lit(po.float_value, result); break;
case LLAMA_KV_OVERRIDE_BOOL: gk_set_lit(po.bool_value, result); break;
case LLAMA_KV_OVERRIDE_BOOL: gk_set_lit(po.bool_value, result); break;
default: GGML_ASSERT(false && "Impossible: Unhandled override tag type");
}
return true;