llama : throw error on missing KV pairs in model meta data

This commit is contained in:
Georgi Gerganov 2023-08-16 13:44:35 +03:00
parent c1fe0aba72
commit f634b292c9
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 28 additions and 16 deletions

1
ggml.h
View file

@@ -1752,6 +1752,7 @@ extern "C" {
     GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
     GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
 
+    // results are undefined if the wrong type is used for the key
     GGML_API uint8_t  gguf_get_val_u8  (struct gguf_context * ctx, int i);
     GGML_API int8_t   gguf_get_val_i8  (struct gguf_context * ctx, int i);
     GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i);

llama.cpp
View file

@@ -107,6 +107,7 @@
 static void llama_log_internal(llama_log_level level, const char* format, ...);
 static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data);
 
 #define LLAMA_LOG_INFO(...)  llama_log_internal(LLAMA_LOG_LEVEL_INFO , __VA_ARGS__)
 #define LLAMA_LOG_WARN(...)  llama_log_internal(LLAMA_LOG_LEVEL_WARN , __VA_ARGS__)
 #define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__)
@@ -1274,24 +1275,34 @@ static void llama_model_load_internal(
     {
         struct gguf_context * ctx = ml->ctx_gguf;
 
-        hparams.n_vocab        = gguf_get_arr_n (ctx, gguf_find_key(ctx, "tokenizer.ggml.tokens"));
-        hparams.n_ctx          = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.context_length"));
-        hparams.n_embd         = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.embedding_length"));
-        hparams.n_ff           = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.feed_forward_length"));
-        hparams.n_head         = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.attention.head_count"));
-        hparams.n_layer        = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.block_count"));
-        hparams.n_rot          = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.rope.dimension_count"));
-        hparams.f_norm_rms_eps = gguf_get_val_f32(ctx, gguf_find_key(ctx, "llama.attention.layer_norm_rms_epsilon"));
-
-        // n_head_kv default to n_head
-        hparams.n_head_kv = hparams.n_head;
-        {
-            const int idx = gguf_find_key(ctx, "llama.attention.head_count_kv");
-            if (idx >= 0) {
-                hparams.n_head_kv = gguf_get_val_u32(ctx, idx);
-            }
-        }
+#define GGUF_GET(dst, func, type, req, key) \
+        { \
+            const int kid = gguf_find_key(ctx, key); \
+            if (kid >= 0) { \
+                enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
+                if (ktype != (type)) { \
+                    throw std::runtime_error(format("key %s has wrong type: %d", key, ktype)); \
+                } \
+                (dst) = func(ctx, kid); \
+            } else if (req) { \
+                throw std::runtime_error(format("key not found in model: %s", key)); \
+            } \
+        }
+
+        GGUF_GET(hparams.n_vocab,        gguf_get_arr_n,   GGUF_TYPE_ARRAY,   true, "tokenizer.ggml.tokens");
+        GGUF_GET(hparams.n_ctx,          gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.context_length");
+        GGUF_GET(hparams.n_embd,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.embedding_length");
+        GGUF_GET(hparams.n_ff,           gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.feed_forward_length");
+        GGUF_GET(hparams.n_head,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.attention.head_count");
+        GGUF_GET(hparams.n_layer,        gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.block_count");
+        GGUF_GET(hparams.n_rot,          gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.rope.dimension_count");
+        GGUF_GET(hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, "llama.attention.layer_norm_rms_epsilon");
+
+        // n_head_kv is optional, default to n_head
+        hparams.n_head_kv = hparams.n_head;
+        GGUF_GET(hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv");
+
+#undef GGUF_GET
         switch (hparams.n_layer) {
             case 26: model.type = e_model::MODEL_3B; break;
             case 32: model.type = e_model::MODEL_7B; break;