save and load head_count_kv in lora checkpoints
parent 7aa9ea7f20
commit 48d3509190
1 changed file with 6 additions and 1 deletion
|
@@ -174,6 +174,7 @@ static const char * LLM_KV_EMBEDDING_LENGTH = "%s.embedding_length";
 static const char * LLM_KV_BLOCK_COUNT = "%s.block_count";
 static const char * LLM_KV_FEED_FORWARD_LENGTH = "%s.feed_forward_length";
 static const char * LLM_KV_ATTENTION_HEAD_COUNT = "%s.attention.head_count";
+static const char * LLM_KV_ATTENTION_HEAD_COUNT_KV = "%s.attention.head_count_kv";
 static const char * LLM_KV_ATTENTION_LAYERNORM_RMS_EPS = "%s.attention.layer_norm_rms_epsilon";
 static const char * LLM_KV_ROPE_DIMENSION_COUNT = "%s.rope.dimension_count";
 static const char * LLM_KV_ROPE_FREQ_BASE = "%s.rope.freq_base"; // TODO load in llama.cpp
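These LLM_KV_* strings are key templates: the "%s" placeholder is filled with the architecture name before a key is written to or read from the GGUF checkpoint. A minimal sketch of that expansion, assuming an snprintf-based helper like the kv() calls used in the hunks below (kv_sketch and the buffer size are illustrative, not part of this commit):

#include <stdio.h>

static const char * LLM_KV_ATTENTION_HEAD_COUNT_KV = "%s.attention.head_count_kv";

// Format the architecture name into the "%s" placeholder of a key template,
// e.g. kv_sketch(LLM_KV_ATTENTION_HEAD_COUNT_KV, "llama") -> "llama.attention.head_count_kv".
static const char * kv_sketch(const char * fmt, const char * arch) {
    static char buf[256];
    snprintf(buf, sizeof(buf), fmt, arch);
    return buf;
}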
@@ -865,8 +866,11 @@ static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context
     GGUF_GET_KEY(fctx, model->hparams.n_head, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_ATTENTION_HEAD_COUNT));
     GGUF_GET_KEY(fctx, model->hparams.n_layer, gguf_get_val_u32, GGUF_TYPE_UINT32, true, kv(LLM_KV_BLOCK_COUNT));

+    model->hparams.n_head_kv = model->hparams.n_head;
+    GGUF_GET_KEY(fctx, model->hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));
+
     model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head;
     GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));

     float rope_freq_scale = 1.0f;
     GGUF_GET_KEY(fctx, lora->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
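The head_count_kv key is read with required=false, so n_head_kv first defaults to n_head and is only overwritten when the key is present; older lora checkpoints that lack the key keep loading. A rough sketch of that optional-read pattern using the plain gguf API (read_u32_or is an illustrative helper, not the project's GGUF_GET_KEY macro; the gguf declarations lived in ggml.h at the time of this commit):

#include <stdint.h>
#include "ggml.h"

// Read an optional uint32 key from a GGUF context, falling back to a default
// when the key is absent - the same effect as GGUF_GET_KEY(..., false, ...).
static uint32_t read_u32_or(struct gguf_context * ctx, const char * key, uint32_t def) {
    const int kid = gguf_find_key(ctx, key); // returns -1 when the key is not present
    if (kid < 0) {
        return def;
    }
    return gguf_get_val_u32(ctx, kid);
}

// Usage corresponding to the hunk above:
//   model->hparams.n_head_kv = read_u32_or(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV), model->hparams.n_head);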
@@ -939,6 +943,7 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
     gguf_set_val_u32(fctx, kv(LLM_KV_EMBEDDING_LENGTH), model->hparams.n_embd);
     gguf_set_val_u32(fctx, kv(LLM_KV_FEED_FORWARD_LENGTH), model->hparams.n_ff);
     gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head);
+    gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV), model->hparams.n_head_kv);
     gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT), model->hparams.n_layer);
     gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_rot);
     gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), lora->hparams.f_norm_rms_eps);
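On the save side the new key is just another uint32 value in the checkpoint's GGUF metadata. A self-contained sketch of the write/read round trip for this key, under the assumption that the architecture name is "llama" (the file name and the hard-coded head count are illustrative):

#include <stdio.h>
#include <stdint.h>
#include "ggml.h" // gguf_* declarations; newer ggml trees split them into gguf.h

int main(void) {
    // Write: store head_count_kv as metadata, as save_llama_lora_gguf does above.
    struct gguf_context * wctx = gguf_init_empty();
    gguf_set_val_u32(wctx, "llama.attention.head_count_kv", 8);
    gguf_write_to_file(wctx, "lora-checkpoint-sketch.gguf", /*only_meta =*/ true);
    gguf_free(wctx);

    // Read it back, treating the key as optional like load_llama_lora_gguf does.
    struct gguf_init_params params = { /*no_alloc =*/ true, /*ctx =*/ NULL };
    struct gguf_context * rctx = gguf_init_from_file("lora-checkpoint-sketch.gguf", params);
    const int kid = gguf_find_key(rctx, "llama.attention.head_count_kv");
    if (kid >= 0) {
        printf("head_count_kv = %u\n", (unsigned) gguf_get_val_u32(rctx, kid));
    }
    gguf_free(rctx);
    return 0;
}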