llama : throw error on missing KV pairs in model meta data
This commit is contained in:
parent c1fe0aba72
commit f634b292c9

2 changed files with 28 additions and 16 deletions
ggml.h (+1, -0)

@@ -1752,6 +1752,7 @@ extern "C" {
     GGML_API enum gguf_type gguf_get_kv_type (struct gguf_context * ctx, int i);
     GGML_API enum gguf_type gguf_get_arr_type(struct gguf_context * ctx, int i);
 
+    // results are undefined if the wrong type is used for the key
     GGML_API uint8_t  gguf_get_val_u8  (struct gguf_context * ctx, int i);
     GGML_API int8_t   gguf_get_val_i8  (struct gguf_context * ctx, int i);
     GGML_API uint16_t gguf_get_val_u16 (struct gguf_context * ctx, int i);
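The comment added to ggml.h flags a real footgun: the typed getters do not validate against the stored KV type. A minimal caller-side sketch of the check this implies, using only the public API declared above (the helper name and default-value convention are illustrative, not part of ggml):

#include <stdint.h>
#include "ggml.h"

// Illustrative helper (not part of ggml): read a u32 KV pair only after
// confirming that the key exists and that its stored type matches, since
// gguf_get_val_u32() is undefined on a mismatched type.
static uint32_t gguf_read_u32_or(struct gguf_context * ctx, const char * key, uint32_t def) {
    const int kid = gguf_find_key(ctx, key);
    if (kid < 0) {
        return def; // key absent
    }
    if (gguf_get_kv_type(ctx, kid) != GGUF_TYPE_UINT32) {
        return def; // wrong type: do not call the typed getter
    }
    return gguf_get_val_u32(ctx, kid);
}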
llama.cpp (+27, -16)

@@ -107,6 +107,7 @@
 static void llama_log_internal(llama_log_level level, const char* format, ...);
 static void llama_log_callback_default(llama_log_level level, const char * text, void * user_data);
 
 #define LLAMA_LOG_INFO(...)  llama_log_internal(LLAMA_LOG_LEVEL_INFO , __VA_ARGS__)
 #define LLAMA_LOG_WARN(...)  llama_log_internal(LLAMA_LOG_LEVEL_WARN , __VA_ARGS__)
 #define LLAMA_LOG_ERROR(...) llama_log_internal(LLAMA_LOG_LEVEL_ERROR, __VA_ARGS__)
@@ -1274,24 +1275,34 @@ static void llama_model_load_internal(
     {
         struct gguf_context * ctx = ml->ctx_gguf;
 
-        hparams.n_vocab        = gguf_get_arr_n (ctx, gguf_find_key(ctx, "tokenizer.ggml.tokens"));
-        hparams.n_ctx          = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.context_length"));
-        hparams.n_embd         = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.embedding_length"));
-        hparams.n_ff           = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.feed_forward_length"));
-        hparams.n_head         = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.attention.head_count"));
-        hparams.n_layer        = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.block_count"));
-        hparams.n_rot          = gguf_get_val_u32(ctx, gguf_find_key(ctx, "llama.rope.dimension_count"));
-        hparams.f_norm_rms_eps = gguf_get_val_f32(ctx, gguf_find_key(ctx, "llama.attention.layer_norm_rms_epsilon"));
+#define GGUF_GET(dst, func, type, req, key) \
+        { \
+            const int kid = gguf_find_key(ctx, key); \
+            if (kid >= 0) { \
+                enum gguf_type ktype = gguf_get_kv_type(ctx, kid); \
+                if (ktype != (type)) { \
+                    throw std::runtime_error(format("key %s has wrong type: %d", key, ktype)); \
+                } \
+                (dst) = func(ctx, kid); \
+            } else if (req) { \
+                throw std::runtime_error(format("key not found in model: %s", key)); \
+            } \
+        }
+
+        GGUF_GET(hparams.n_vocab,        gguf_get_arr_n,   GGUF_TYPE_ARRAY,   true, "tokenizer.ggml.tokens");
+        GGUF_GET(hparams.n_ctx,          gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.context_length");
+        GGUF_GET(hparams.n_embd,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.embedding_length");
+        GGUF_GET(hparams.n_ff,           gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.feed_forward_length");
+        GGUF_GET(hparams.n_head,         gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.attention.head_count");
+        GGUF_GET(hparams.n_layer,        gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.block_count");
+        GGUF_GET(hparams.n_rot,          gguf_get_val_u32, GGUF_TYPE_UINT32,  true, "llama.rope.dimension_count");
+        GGUF_GET(hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, "llama.attention.layer_norm_rms_epsilon");
 
-        // n_head_kv default to n_head
-        hparams.n_head_kv = hparams.n_head;
-        {
-            const int idx = gguf_find_key(ctx, "llama.attention.head_count_kv");
-            if (idx >= 0) {
-                hparams.n_head_kv = gguf_get_val_u32(ctx, idx);
-            }
-        }
+        // n_head_kv is optional, default to n_head
+        hparams.n_head_kv = hparams.n_head;
+        GGUF_GET(hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, "llama.attention.head_count_kv");
+#undef GGUF_GET
     }
 
     switch (hparams.n_layer) {
         case 26: model.type = e_model::MODEL_3B; break;
         case 32: model.type = e_model::MODEL_7B; break;
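For readers skimming the diff: GGUF_GET centralizes what was previously an unchecked gguf_find_key / gguf_get_val_* pair per hyperparameter. A missing required key or a type mismatch now throws std::runtime_error instead of passing -1 into a typed getter. The sketch below reproduces that control flow outside the loader so it can be compiled and run standalone; the std::map stub, the get_u32 helper, and the kv_type enum are illustrative stand-ins, not llama.cpp API.

#include <cstdint>
#include <cstdio>
#include <map>
#include <stdexcept>
#include <string>

// Stand-in for the GGUF KV store; in llama.cpp the lookups go through
// gguf_find_key()/gguf_get_kv_type()/gguf_get_val_u32() on a gguf_context.
enum kv_type { KV_UINT32, KV_FLOAT32, KV_ARRAY };
struct kv_entry { kv_type type; uint32_t u32; };

// Mirrors GGUF_GET's behavior for a u32 value: a type mismatch always
// throws; a missing key throws only when the key is required.
static uint32_t get_u32(const std::map<std::string, kv_entry> & kv,
                        const std::string & key, bool required, uint32_t fallback) {
    const auto it = kv.find(key);
    if (it == kv.end()) {
        if (required) {
            throw std::runtime_error("key not found in model: " + key);
        }
        return fallback; // optional key: keep the default (e.g. n_head_kv = n_head)
    }
    if (it->second.type != KV_UINT32) {
        throw std::runtime_error("key " + key + " has wrong type");
    }
    return it->second.u32;
}

int main() {
    std::map<std::string, kv_entry> kv = {
        { "llama.context_length", { KV_UINT32, 4096 } },
        // "llama.block_count" deliberately omitted to trigger the error path
    };

    try {
        const uint32_t n_ctx   = get_u32(kv, "llama.context_length", /*required=*/true, 0);
        const uint32_t n_layer = get_u32(kv, "llama.block_count",    /*required=*/true, 0);
        std::printf("n_ctx=%u n_layer=%u\n", n_ctx, n_layer);
    } catch (const std::runtime_error & err) {
        // Before this commit the loader would have read garbage hparams instead.
        std::printf("model load failed: %s\n", err.what());
    }
    return 0;
}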