Refactor... basically everything!

KerfuffleV2 2023-11-16 19:08:48 -07:00
parent 69be5c3d6d
commit 8c9f776952

llama.cpp

@@ -75,6 +75,7 @@
 #include <set>
 #include <sstream>
 #include <thread>
+#include <type_traits>
 #include <unordered_map>

 #if defined(_MSC_VER)
@@ -1682,6 +1683,161 @@ static std::string llama_format_tensor_shape(const struct ggml_tensor * t) {
     return buf;
 }

+namespace GGUFMeta {
+    template <typename T, gguf_type gt_, T (*gfun)(const gguf_context *, const int)>
+    struct GKV_Base_Type {
+        static constexpr gguf_type gt = gt_;
+
+        static T getter(const gguf_context * ctx, const int kid) {
+            return gfun(ctx, kid);
+        }
+    };
+
+    template<typename T> struct GKV_Base;
+
+    template<> struct GKV_Base<bool        >: GKV_Base_Type<bool,         GGUF_TYPE_BOOL,    gguf_get_val_bool> {};
+    template<> struct GKV_Base<uint8_t     >: GKV_Base_Type<uint8_t,      GGUF_TYPE_UINT8,   gguf_get_val_u8  > {};
+    template<> struct GKV_Base<uint16_t    >: GKV_Base_Type<uint16_t,     GGUF_TYPE_UINT16,  gguf_get_val_u16 > {};
+    template<> struct GKV_Base<uint32_t    >: GKV_Base_Type<uint32_t,     GGUF_TYPE_UINT32,  gguf_get_val_u32 > {};
+    template<> struct GKV_Base<uint64_t    >: GKV_Base_Type<uint64_t,     GGUF_TYPE_UINT64,  gguf_get_val_u64 > {};
+    template<> struct GKV_Base<int8_t      >: GKV_Base_Type<int8_t,       GGUF_TYPE_INT8,    gguf_get_val_i8  > {};
+    template<> struct GKV_Base<int16_t     >: GKV_Base_Type<int16_t,      GGUF_TYPE_INT16,   gguf_get_val_i16 > {};
+    template<> struct GKV_Base<int32_t     >: GKV_Base_Type<int32_t,      GGUF_TYPE_INT32,   gguf_get_val_i32 > {};
+    template<> struct GKV_Base<int64_t     >: GKV_Base_Type<int64_t,      GGUF_TYPE_INT64,   gguf_get_val_i64 > {};
+    template<> struct GKV_Base<float       >: GKV_Base_Type<float,        GGUF_TYPE_FLOAT32, gguf_get_val_f32 > {};
+    template<> struct GKV_Base<double      >: GKV_Base_Type<double,       GGUF_TYPE_FLOAT64, gguf_get_val_f64 > {};
+    template<> struct GKV_Base<const char *>: GKV_Base_Type<const char *, GGUF_TYPE_STRING,  gguf_get_val_str > {};
+
+    struct GetArrayLen { int value; };
+
+    template<> struct GKV_Base<GetArrayLen> {
+        public:
+        static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
+        static GetArrayLen getter(const gguf_context * ctx, const int k) {
+            return GetArrayLen{gguf_get_arr_n(ctx, k)};
+        }
+    };
+
+    struct ArrayInfo {
+        const gguf_type gt;
+        const size_t length;
+        const void * data;
+    };
+
+    template<> struct GKV_Base<ArrayInfo> {
+        public:
+        static constexpr gguf_type gt = GGUF_TYPE_ARRAY;
+        static ArrayInfo getter(const gguf_context * ctx, const int k) {
+            return ArrayInfo {
+                gguf_get_arr_type(ctx, k),
+                size_t(gguf_get_arr_n(ctx, k)),
+                gguf_get_arr_data(ctx, k),
+            };
+        }
+    };
+
+    template<typename T>
+    class GKV: public GKV_Base<T> {
+        GKV() = delete;
+
+        public:
+        static T get_kv(const gguf_context * ctx, const int k) {
+            const enum gguf_type kt = gguf_get_kv_type(ctx, k);
+
+            if (kt != GKV::gt) {
+                throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
+                    gguf_get_key(ctx, k), gguf_type_name(kt), gguf_type_name(GKV::gt)));
+            }
+            return GKV::getter(ctx, k);
+        }
+
+        // This can't be uncommented.
+        // template<typename OT> static bool try_override(OT & target, const struct llama_model_kv_override *override) = delete;
+
+        template<typename OT>
+        static typename std::enable_if<std::is_same<OT, bool>::value, bool>::type
+        try_override(OT & target, const struct llama_model_kv_override *override) {
+            if (!override) {
+                return false;
+            }
+            if (override->tag != LLAMA_KV_OVERRIDE_BOOL) {
+                return false;
+            }
+            target = override->bool_value;
+            return true;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<!std::is_same<OT, bool>::value && std::is_integral<OT>::value, bool>::type
+        try_override(OT & target, const struct llama_model_kv_override *override) {
+            if (!override) {
+                return false;
+            }
+            if (override->tag != LLAMA_KV_OVERRIDE_INT) {
+                return false;
+            }
+            if (override->int_value < 0 && !std::is_signed<OT>::value) {
+                return false;
+            }
+            target = override->int_value;
+            return true;
+        }
+
+        template<typename OT>
+        static typename std::enable_if<std::is_floating_point<OT>::value, bool>::type
+        try_override(OT & target, const struct llama_model_kv_override *override) {
+            if (!override) {
+                return false;
+            }
+            if (override->tag != LLAMA_KV_OVERRIDE_FLOAT) {
+                return false;
+            }
+            target = override->float_value;
+            return true;
+        }
+
+        static bool set(const gguf_context * ctx, const int k, T & target, const struct llama_model_kv_override *override = nullptr) {
+            if (try_override<T>(target, override)) {
+                return true;
+            }
+            target = get_kv(ctx, k);
+            return true;
+        }
+
+        template <typename TT>
+        static bool set(const gguf_context * ctx, const char * key, TT & target, const struct llama_model_kv_override *override = nullptr) {
+            const int kid = gguf_find_key(ctx, key);
+            if (kid < 0) {
+                return false;
+            }
+            return GKV<TT>::set(ctx, kid, target, override);
+        }
+
+        template <typename TT>
+        static bool set(const gguf_context * ctx, const std::string & key, TT & target, const struct llama_model_kv_override *override = nullptr) {
+            return GKV<TT>::set(ctx, key.c_str(), target, override);
+        }
+    };
+
+    template<>
+    class GKV<std::string>: public GKV_Base<const char *> {
+        using BT = const char *;
+
+        public:
+        static bool set(const gguf_context * ctx, const int k, std::string & target, const struct llama_model_kv_override *override = nullptr) {
+            (void)override;
+            target = std::string(GKV<BT>::get_kv(ctx, k));
+            return true;
+        }
+
+        static bool set(const gguf_context * ctx, const char * key, std::string & target, const struct llama_model_kv_override *override = nullptr) {
+            return GKV<BT>::set(ctx, key, target, override);
+        }
+
+        static bool set(const gguf_context * ctx, const std::string & key, std::string & target, const struct llama_model_kv_override *override = nullptr) {
+            return GKV<BT>::set(ctx, key, target, override);
+        }
+    };
+}
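
A note on the GGUFMeta block above: GKV_Base is a compile-time table mapping each C++ type to its GGUF type tag and getter function, GKV<T>::get_kv adds the type check, and GKV<T>::set layers optional command-line overrides on top. The sketch below is illustrative only, not part of the commit; it assumes a loaded gguf_context * ctx, the standard GGUF key names llama.context_length and general.architecture, and the llama_model_kv_override struct this commit adds to llama.h:

    // Typed reads: set() returns false if the key is absent and throws if the
    // stored GGUF type does not match the requested C++ type.
    uint32_t n_ctx = 0;
    std::string arch;
    GGUFMeta::GKV<uint32_t>::set(ctx, "llama.context_length", n_ctx);
    GGUFMeta::GKV<std::string>::set(ctx, "general.architecture", arch);

    // An override with a compatible tag wins over the file value. The key
    // lookup still happens first, so the key must also exist in the file.
    struct llama_model_kv_override ov = {};
    ov.tag       = LLAMA_KV_OVERRIDE_INT;
    ov.int_value = 4096;
    GGUFMeta::GKV<uint32_t>::set(ctx, "llama.context_length", n_ctx, &ov); // n_ctx == 4096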

 struct llama_model_loader {
     int n_kv      = 0;
     int n_tensors = 0;
@@ -1826,117 +1982,10 @@ struct llama_model_loader {
         }
     }

-private:
-    template <typename T> struct gk_get_arrlen { T & output; };
-    template <typename TI, typename TO> struct gk_set_literal { TI & input; TO & output; };
-
-    template <typename T>
-    void gk_set(int kid, T & result) {
-        (void)result;
-        throw std::runtime_error(format("request for key id %d with unhandled result type: %s", kid, typeid(T).name()));
-    }
-
-    void gk_set(const int k, uint8_t     & r) { r = gguf_get_val_u8  (ctx_gguf, k); }
-    void gk_set(const int k, uint16_t    & r) { r = gguf_get_val_u16 (ctx_gguf, k); }
-    void gk_set(const int k, uint32_t    & r) { r = gguf_get_val_u32 (ctx_gguf, k); }
-    void gk_set(const int k, uint64_t    & r) { r = gguf_get_val_u64 (ctx_gguf, k); }
-    void gk_set(const int k, int8_t      & r) { r = gguf_get_val_i8  (ctx_gguf, k); }
-    void gk_set(const int k, int16_t     & r) { r = gguf_get_val_i16 (ctx_gguf, k); }
-    void gk_set(const int k, int32_t     & r) { r = gguf_get_val_i32 (ctx_gguf, k); }
-    void gk_set(const int k, int64_t     & r) { r = gguf_get_val_i64 (ctx_gguf, k); }
-    void gk_set(const int k, float       & r) { r = gguf_get_val_f32 (ctx_gguf, k); }
-    void gk_set(const int k, double      & r) { r = gguf_get_val_f64 (ctx_gguf, k); }
-    void gk_set(const int k, bool        & r) { r = gguf_get_val_bool(ctx_gguf, k); }
-    void gk_set(const int k, std::string & r) { r = std::string(gguf_get_val_str(ctx_gguf, k)); }
-
-    template<typename T>
-    typename std::enable_if<std::is_integral<T>::value, void>::type
-    gk_set(const int k, struct gk_get_arrlen<T> r) { r.output = gguf_get_arr_n(ctx_gguf, k); }
-
-    template<typename TI, typename TO>
-    void gk_set_lit(const TI i, TO o) {
-        (void)i; (void)o;
-        throw std::runtime_error(format("gk_set_lit can't handle types: in=%s, out=%s",
-            typeid(TI).name(), typeid(TO).name()));
-    }
-
-    template<typename T>
-    typename std::enable_if<std::is_integral<T>::value, void>::type
-    gk_set_lit(const int64_t i, T & o) { o = T(i); }
-
-    template<typename T>
-    typename std::enable_if<std::is_floating_point<T>::value, void>::type
-    gk_set_lit(const double i, T & o) { o = T(i); }
-
-    template<typename T>
-    void gk_set_lit(const T & i, T & o) { o = i; }
-
-public:
-    template<typename T>
-    bool get_key(const std::string & key, T & result, const bool required = false) {
-        enum gguf_type gt = GGUF_TYPE_COUNT;
-        enum llama_model_kv_override_type ot = LLAMA_KV_OVERRIDE_INT;
-        bool is_signed = false, can_override = true;
-
-        if (std::is_same<T, uint8_t>::value) {
-            gt = GGUF_TYPE_UINT8;
-        } else if (std::is_same<T, uint16_t>::value) {
-            gt = GGUF_TYPE_UINT16;
-        } else if (std::is_same<T, uint32_t>::value) {
-            gt = GGUF_TYPE_UINT32;
-        } else if (std::is_same<T, uint64_t>::value) {
-            gt = GGUF_TYPE_UINT64;
-        } else if (std::is_same<T, int8_t>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_INT8;
-        } else if (std::is_same<T, int16_t>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_INT16;
-        } else if (std::is_same<T, int32_t>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_INT32;
-        } else if (std::is_same<T, int64_t>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_INT64;
-        } else if (std::is_same<T, float>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_FLOAT32;
-            ot = LLAMA_KV_OVERRIDE_FLOAT;
-        } else if (std::is_same<T, double>::value) {
-            is_signed = true;
-            gt = GGUF_TYPE_FLOAT64;
-            ot = LLAMA_KV_OVERRIDE_FLOAT;
-        } else if (std::is_same<T, bool>::value) {
-            gt = GGUF_TYPE_BOOL;
-            ot = LLAMA_KV_OVERRIDE_BOOL;
-        } else if (std::is_same<T, std::string>::value) {
-            can_override = false;
-            gt = GGUF_TYPE_STRING;
-        } else {
-            throw std::runtime_error(format("request for key '%s' with unknown result type: %s", key.c_str(), typeid(T).name()));
-        }
-
-        if (can_override) {
-            auto it = kv_overrides.find(key);
-            if (it != kv_overrides.end()) {
-                struct llama_model_kv_override & po = it->second;
-                if (po.tag != ot) {
-                    // Bad type
-                    // FIXME: Error reporting
-                } else if (ot == LLAMA_KV_OVERRIDE_INT && po.int_value < 0 && !is_signed) {
-                    // Out of range
-                    // FIXME: Error reporting
-                } else {
-                    // FIXME: Possible informational output
-                    switch (po.tag) {
-                        case LLAMA_KV_OVERRIDE_INT:   gk_set_lit(po.int_value,   result); break;
-                        case LLAMA_KV_OVERRIDE_FLOAT: gk_set_lit(po.float_value, result); break;
-                        case LLAMA_KV_OVERRIDE_BOOL:  gk_set_lit(po.bool_value,  result); break;
-                        default: GGML_ASSERT(false && "Impossible: Unhandled override tag type");
-                    }
-                    return true;
-                }
-            }
-        }
-
+    template<typename T> typename std::enable_if<std::is_integral<T>::value, bool>::type
+    get_arr_n(const std::string & key, T & result, const bool required = false) {
         const int kid = gguf_find_key(ctx_gguf, key.c_str());

         if (kid < 0) {
             if (required) {
                 throw std::runtime_error(format("key not found in model: %s", key.c_str()));
@@ -1944,23 +1993,37 @@ struct llama_model_loader {
             return false;
         }

-        const enum gguf_type ktype = gguf_get_kv_type(ctx_gguf, kid);
-
-        if (ktype == GGUF_TYPE_ARRAY && ot == LLAMA_KV_OVERRIDE_INT) {
-            gk_get_arrlen<T> arrlen = {result};
-            gk_set(kid, arrlen);
-            return true;
-        }
-
-        if (ktype != gt) {
-            throw std::runtime_error(format("key %s has wrong type %s but expected type %s",
-                key.c_str(), gguf_type_name(ktype), gguf_type_name(gt)));
-        }
-
-        gk_set(kid, result);
+        struct GGUFMeta::ArrayInfo arr_info =
+            GGUFMeta::GKV<GGUFMeta::ArrayInfo>::get_kv(ctx_gguf, kid);
+
+        result = arr_info.length;
         return true;
     }

+    template<typename T> typename std::enable_if<std::is_integral<T>::value, bool>::type
+    get_arr_n(const enum llm_kv kid, T & result, const bool required = true) {
+        return get_arr_n(llm_kv(kid), result, required);
+    }
+
     template<typename T>
-    bool get_key(const enum llm_kv kid, T & result, const bool required = false) {
+    bool get_key(const std::string & key, T & result, const bool required = true) {
+        auto it = kv_overrides.find(key);
+
+        const struct llama_model_kv_override * override =
+            it != kv_overrides.end() ? &it->second : nullptr;
+
+        const bool found = GGUFMeta::GKV<T>::set(ctx_gguf, key, result, override);
+
+        if (required && !found) {
+            throw std::runtime_error(format("key not found in model: %s", key.c_str()));
+        }
+
+        return found;
+    }
+
+    template<typename T>
+    bool get_key(const enum llm_kv kid, T & result, const bool required = true) {
         return get_key(llm_kv(kid), result, required);
     }
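
Two caller-visible changes are worth noting here: `required` now defaults to true for get_key (the old default was false), and array-valued keys get a dedicated get_arr_n that reports the element count rather than special-casing arrays inside get_key. A hypothetical call site, illustrative only, assuming a llama_model_loader ml as defined above:

    uint32_t n_embd    = 0;
    uint32_t n_vocab   = 0;
    float    clamp_kqv = 0.0f;

    // required defaults to true: a missing key throws here...
    ml.get_key(LLM_KV_EMBEDDING_LENGTH, n_embd);

    // ...while optional reads pass false and check the return value:
    if (!ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, clamp_kqv, false)) {
        clamp_kqv = 0.0f; // keep the default when the key is absent
    }

    // Array-valued keys go through get_arr_n, which stores the array length:
    ml.get_arr_n(LLM_KV_TOKENIZER_LIST, n_vocab);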
@@ -2222,12 +2285,12 @@ static void llm_load_hparams(
     ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);

     // get hparams kv
-    ml.get_key(LLM_KV_TOKENIZER_LIST,       hparams.n_vocab,     true);
-    ml.get_key(LLM_KV_CONTEXT_LENGTH,       hparams.n_ctx_train, true);
-    ml.get_key(LLM_KV_EMBEDDING_LENGTH,     hparams.n_embd,      true);
-    ml.get_key(LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff,        true);
-    ml.get_key(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head,      true);
-    ml.get_key(LLM_KV_BLOCK_COUNT,          hparams.n_layer,     true);
+    ml.get_arr_n(LLM_KV_TOKENIZER_LIST,       hparams.n_vocab);
+    ml.get_key  (LLM_KV_CONTEXT_LENGTH,       hparams.n_ctx_train);
+    ml.get_key  (LLM_KV_EMBEDDING_LENGTH,     hparams.n_embd);
+    ml.get_key  (LLM_KV_FEED_FORWARD_LENGTH,  hparams.n_ff);
+    ml.get_key  (LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head);
+    ml.get_key  (LLM_KV_BLOCK_COUNT,          hparams.n_layer);

     // n_head_kv is optional, default to n_head
     hparams.n_head_kv = hparams.n_head;
@@ -2251,8 +2314,8 @@ static void llm_load_hparams(
         // rope_freq_scale (inverse of the kv) is optional
         float ropescale = 0.0f;
-        ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false);
-        if (ropescale == 0.0f) { // try the old key name
+        if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) {
+            // try the old key name
             ml.get_key(LLM_KV_ROPE_SCALE_LINEAR, ropescale, false);
         }
         hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;
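
The fallback above now keys off get_key's return value (key presence) instead of re-testing for the 0.0f sentinel after the first read, which the new bool-returning get_key makes possible. Since the stored factor is the inverse of the runtime frequency scale, a worked example with a hypothetical value:

    float ropescale = 4.0f; // hypothetical: model trained with 4x linear rope scaling
    // rope_freq_scale_train is the reciprocal of the stored factor
    // (1.0f/4.0f == 0.25f); an unset factor of 0.0f maps to 1.0f.
    float rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale;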
@@ -2276,7 +2339,7 @@ static void llm_load_hparams(
     switch (model.arch) {
         case LLM_ARCH_LLAMA:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                 switch (hparams.n_layer) {
                     case 26: model.type = e_model::MODEL_3B; break;
@@ -2290,7 +2353,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_FALCON:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_7B; break;
@@ -2300,7 +2363,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_BAICHUAN:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_7B; break;
                     case 40: model.type = e_model::MODEL_13B; break;
@@ -2309,7 +2372,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_STARCODER:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_1B; break;
                     case 36: model.type = e_model::MODEL_3B; break;
@@ -2320,7 +2383,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_PERSIMMON:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

                 switch (hparams.n_layer) {
                     case 36: model.type = e_model::MODEL_8B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
@@ -2328,7 +2391,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_REFACT:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);

                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_1B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
@@ -2336,7 +2399,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_BLOOM:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

                 switch (hparams.n_layer) {
                     case 24: model.type = e_model::MODEL_1B; break;
@@ -2351,9 +2414,9 @@ static void llm_load_hparams(
             {
                 hparams.f_clamp_kqv = 0.0f;

-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV, hparams.f_clamp_kqv, false);
-                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias, true);
+                ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);

                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_7B; break;
@@ -2363,7 +2426,7 @@ static void llm_load_hparams(
             } break;
         case LLM_ARCH_STABLELM:
             {
-                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps, true);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);

                 switch (hparams.n_layer) {
                     case 32: model.type = e_model::MODEL_3B; break;
@@ -2411,7 +2474,7 @@ static void llm_load_vocab(
     {
         std::string tokenizer_name;

-        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name, true);
+        ml.get_key(LLM_KV_TOKENIZER_MODEL, tokenizer_name);

         if (tokenizer_name == "llama") {
             vocab.type = LLAMA_VOCAB_TYPE_SPM;