diff --git a/examples/finetune/finetune.cpp b/examples/finetune/finetune.cpp
index 17b559a10..5c6fa639c 100644
--- a/examples/finetune/finetune.cpp
+++ b/examples/finetune/finetune.cpp
@@ -174,6 +174,7 @@ static const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length";
 static const char * LLM_KV_BLOCK_COUNT = "%s.block_count";
 static const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length";
 static const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count";
+static const char * LLM_KV_ATTENTION_HEAD_COUNT_KV = "%s.attention.head_count_kv";
 static const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
 static const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count";
 static const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp
@@ -865,8 +866,11 @@ static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context
     GGUF_GET_KEY(fctx, model->hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
     GGUF_GET_KEY(fctx, model->hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));

+    model->hparams.n_head_kv = model->hparams.n_head;
+    GGUF_GET_KEY(fctx, model->hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
+
     model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head;
-    GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
+    GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));

     float rope_freq_scale = 1.0f;
     GGUF_GET_KEY(fctx, lora->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
@@ -939,6 +943,7 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
     gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd);
     gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff);
     gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head);
+    gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV), model->hparams.n_head_kv);
     gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT), model->hparams.n_layer);
     gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_rot);
     gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), lora->hparams.f_norm_rms_eps);
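
Note on the loading pattern in `load_llama_lora_gguf` above: `n_head_kv` is pre-seeded with `n_head` before the optional read, so checkpoints written before `%s.attention.head_count_kv` existed still load with the multi-head-attention default, while GQA checkpoints pick up the stored value. Below is a minimal self-contained C++ sketch of that fallback logic; the `read_optional_u32` helper and the sample key/value map are illustrative stand-ins for `GGUF_GET_KEY` with `required = false`, not part of the patch.

#include <cstdint>
#include <cstdio>
#include <map>
#include <string>

// Illustrative stand-in for GGUF_GET_KEY(..., required = false):
// overwrite `dst` only when the key is present, otherwise keep its current value.
static void read_optional_u32(const std::map<std::string, uint32_t> & kv,
                              const std::string & key, uint32_t & dst) {
    const auto it = kv.find(key);
    if (it != kv.end()) {
        dst = it->second;
    }
}

int main() {
    // Hypothetical header contents; a pre-GQA checkpoint would simply omit head_count_kv.
    const std::map<std::string, uint32_t> kv = {
        { "llama.attention.head_count",    32 },
        { "llama.attention.head_count_kv",  8 },
    };

    uint32_t n_head    = 0;
    uint32_t n_head_kv = 0;

    read_optional_u32(kv, "llama.attention.head_count", n_head);

    // Same idea as the patch: default to multi-head attention (n_head_kv == n_head),
    // then let the optional key override it when it exists.
    n_head_kv = n_head;
    read_optional_u32(kv, "llama.attention.head_count_kv", n_head_kv);

    const uint32_t n_gqa = n_head / n_head_kv; // query heads per KV head
    std::printf("n_head=%u n_head_kv=%u n_gqa=%u\n",
                (unsigned) n_head, (unsigned) n_head_kv, (unsigned) n_gqa);
    return 0;
}

With the sample values above this prints n_head=32 n_head_kv=8 n_gqa=4, i.e. four query heads share each KV head; dropping the head_count_kv entry reproduces the old behaviour (n_head_kv == n_head).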