diff --git a/ggml.h b/ggml.h
index 8a1661cfb..48ce71ecd 100644
--- a/ggml.h
+++ b/ggml.h
@@ -1752,6 +1752,7 @@ extern "C" {
     GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
     GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
 
+    // results are undefined if the wrong type is used for the key
     GGML_API uint8_t  gguf_get_val_u8  (struct gguf_context * ctx, int i);
     GGML_API int8_t   gguf_get_val_i8  (struct gguf_context * ctx, int i);
     GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i);
diff --git a/gguf-llama.cpp b/gguf-llama.cpp
index cebe53d10..ec64ef8dc 100644
--- a/gguf-llama.cpp
+++ b/gguf-llama.cpp
@@ -107,6 +107,7 @@ static void llama_log_internal(llama_log_level level, const char* format, ...);
 static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data);
 
+
 #define LLAMA_LOG_INFO(...)  llama_log_internal(LLAMA_LOG_LEVEL_INFO , __VA_ARGS__)
 #define LLAMA_LOG_WARN(...)  llama_log_internal(LLAMA_LOG_LEVEL_WARN , __VA_ARGS__)
 #define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__)
@@ -1274,24 +1275,34 @@ static void llama_model_load_internal(
     {
         struct gguf_context * ctx = ml->ctx_gguf;
 
-        hparams.n_vocab = gguf_get_arr_n (ctx, gguf_find_key(ctx, "tokenizer.ggml.tokens"));
-        hparams.n_ctx   = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.context_length"));
-        hparams.n_embd  = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.embedding_length"));
-        hparams.n_ff    = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.feed_forward_length"));
-        hparams.n_head  = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.attention.head_count"));
-        hparams.n_layer = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.block_count"));
-        hparams.n_rot   = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.rope.dimension_count"));
-        hparams.f_norm_rms_eps = gguf_get_val_f32(ctx, gguf_find_key(ctx, "llama.attention.layer_norm_rms_epsilon"));
-
-        // n_head_kv default to n_head
-        hparams.n_head_kv = hparams.n_head;
-        {
-            const int idx = gguf_find_key(ctx, "llama.attention.head_count_kv");
-            if (idx >= 0) {
-                hparams.n_head_kv = gguf_get_val_u32(ctx, idx);
-            }
+#define GGUF_GET(dst, func, type, req, key) \
+        { \
+            const int kid = gguf_find_key(ctx, key); \
+            if (kid >= 0) { \
+                enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
+                if (ktype != (type)) { \
+                    throw std::runtime_error(format("key %s has wrong type: %d", key, ktype)); \
+                } \
+                (dst) = func(ctx, kid); \
+            } else if (req) { \
+                throw std::runtime_error(format("key not found in model: %s", key)); \
+            } \
         }
+        GGUF_GET(hparams.n_vocab,        gguf_get_arr_n,   GGUF_TYPE_ARRAY,   true, "tokenizer.ggml.tokens");
+        GGUF_GET(hparams.n_ctx,          gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.context_length");
+        GGUF_GET(hparams.n_embd,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.embedding_length");
+        GGUF_GET(hparams.n_ff,           gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.feed_forward_length");
+        GGUF_GET(hparams.n_head,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.attention.head_count");
+        GGUF_GET(hparams.n_layer,        gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.block_count");
+        GGUF_GET(hparams.n_rot,          gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.rope.dimension_count");
+        GGUF_GET(hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, "llama.attention.layer_norm_rms_epsilon");
+
+        // n_head_kv is optional, default to n_head
+        hparams.n_head_kv = hparams.n_head;
+        GGUF_GET(hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv");
+#undef GGUF_GET
+
         switch (hparams.n_layer) {
             case 26: model.type = e_model::MODEL_3B; break;
             case 32: model.type = e_model::MODEL_7B; break;