remove n_rot hparam, as it must always be hparam.n_embd_head()

parent 56a03faf5f
commit 1dbd6bc3d5

1 changed file with 2 additions and 9 deletions
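Note: the removal is safe because n_rot was, by construction, always equal to n_embd / n_head, which is exactly what hparams.n_embd_head() returns; the diff below also drops the assert GGML_ASSERT(n_embd_head == n_rot) that enforced this. A minimal sketch of the invariant (illustrative only; the field defaults and the exact body of n_embd_head() are assumptions inferred from this diff, not copied from the source):

    #include <cassert>
    #include <cstdint>

    struct hparams_sketch {          // hypothetical stand-in for my_llama_hparams
        uint32_t n_embd = 2048;      // illustrative default
        uint32_t n_head = 32;
        // presumed definition, implied by the removed load-time computation
        // model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head;
        uint32_t n_embd_head() const { return n_embd / n_head; }
    };

    int main() {
        hparams_sketch hp;
        // the old default n_rot = 64 coincides with 2048/32; after this commit
        // callers derive the value instead of storing it in a separate field
        assert(hp.n_embd_head() == 64);
        return 0;
    }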
@@ -29,7 +29,6 @@ struct my_llama_hparams {
     uint32_t n_head = 32;
     uint32_t n_head_kv = 32;
     uint32_t n_layer = 32;
-    uint32_t n_rot = 64;

     uint32_t n_gqa() const {
         return n_head/n_head_kv;
@@ -203,7 +202,6 @@ static void print_params(struct my_llama_hparams * params) {
     printf("%s: n_ff: %u\n", __func__, params->n_ff);
     printf("%s: n_head: %u\n", __func__, params->n_head);
     printf("%s: n_layer: %u\n", __func__, params->n_layer);
-    printf("%s: n_rot: %u\n", __func__, params->n_rot);
 }

 static void print_lora_params(struct my_llama_lora_hparams * params) {
@@ -247,7 +245,6 @@ static void init_model(struct llama_model * input, struct my_llama_model * model
     hparams.n_head = llama_model_n_head(input);
     hparams.n_head_kv = llama_model_n_head_kv(input);
     hparams.n_layer = llama_model_n_layer(input);
-    hparams.n_rot = llama_model_n_rot(input);

     model->tok_embeddings = llama_get_model_tensor(input, tn(LLM_TENSOR_TOKEN_EMBD));
     model->norm = llama_get_model_tensor(input, tn(LLM_TENSOR_OUTPUT_NORM));
@@ -535,8 +532,8 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     const int n_layer = hparams.n_layer;
     const int n_head = hparams.n_head;
     const int n_head_kv = hparams.n_head_kv;
-    const int n_rot = hparams.n_rot;
     const int n_ff = hparams.n_ff;
+    const int n_rot = hparams.n_embd_head();
     const int n_embd_head = hparams.n_embd_head();
     const int n_embd_gqa = hparams.n_embd_gqa();
     const float rms_norm_eps = lora->hparams.f_norm_rms_eps;
@@ -544,7 +541,6 @@ static struct ggml_tensor * llama_build_lora_finetune_graphs(
     const float rope_freq_scale = lora->hparams.rope_freq_scale;

     GGML_ASSERT((size_t) n_layer == lora->layers.size());
-    GGML_ASSERT(n_embd_head == n_rot);

     auto set_name = [](struct ggml_tensor * t, const char * n) {
         ggml_set_name(t, n);
@@ -823,9 +819,6 @@ static void load_llama_lora_gguf(struct gguf_context * fctx, struct ggml_context
     model->hparams.n_head_kv = model->hparams.n_head;
     GGUF_GET_KEY(fctx, model->hparams.n_head_kv, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV));

-    model->hparams.n_rot = model->hparams.n_embd / model->hparams.n_head;
-    GGUF_GET_KEY(fctx, model->hparams.n_rot, gguf_get_val_u32, GGUF_TYPE_UINT32, false, kv(LLM_KV_ROPE_DIMENSION_COUNT));
-
     float rope_freq_scale = 1.0f;
     GGUF_GET_KEY(fctx, lora->hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
     GGUF_GET_KEY(fctx, lora->hparams.rope_freq_base, gguf_get_val_f32, GGUF_TYPE_FLOAT32, false, kv(LLM_KV_ROPE_FREQ_BASE));
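With the two removed lines above, the loader no longer reads LLM_KV_ROPE_DIMENSION_COUNT at all; n_rot is simply derived wherever it is needed. If one still wanted to warn about old checkpoints whose stored rope dimension disagrees with n_embd / n_head, a hedged sketch of such a check follows (hypothetical helper, not part of this commit; the literal key string stands in for what kv(LLM_KV_ROPE_DIMENSION_COUNT) would produce):

    #include "ggml.h"   // for gguf_find_key / gguf_get_val_u32
    #include <cstdint>

    // returns true when the checkpoint either omits the key or agrees with
    // the derived value n_embd / n_head
    static bool rope_dims_consistent(const struct gguf_context * fctx,
                                     uint32_t n_embd, uint32_t n_head,
                                     const char * key /* assumed: "llama.rope.dimension_count" */) {
        const int id = gguf_find_key(fctx, key);
        if (id < 0) {
            return true;  // key absent: nothing to compare against
        }
        return gguf_get_val_u32(fctx, id) == n_embd / n_head;
    }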
@@ -899,7 +892,7 @@ static void save_llama_lora_gguf(struct gguf_context * fctx, struct my_llama_mod
     gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT), model->hparams.n_head);
     gguf_set_val_u32(fctx, kv(LLM_KV_ATTENTION_HEAD_COUNT_KV), model->hparams.n_head_kv);
     gguf_set_val_u32(fctx, kv(LLM_KV_BLOCK_COUNT), model->hparams.n_layer);
-    gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_rot);
+    gguf_set_val_u32(fctx, kv(LLM_KV_ROPE_DIMENSION_COUNT), model->hparams.n_embd_head());
     gguf_set_val_f32(fctx, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS), lora->hparams.f_norm_rms_eps);
     gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_FREQ_BASE), lora->hparams.rope_freq_base);
     gguf_set_val_f32(fctx, kv(LLM_KV_ROPE_SCALE_LINEAR), lora->hparams.rope_freq_scale);
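Note that the on-disk format is unchanged by this commit: save_llama_lora_gguf still writes LLM_KV_ROPE_DIMENSION_COUNT, it just computes the value as hparams.n_embd_head() instead of echoing a stored n_rot, so checkpoints written before and after the change remain compatible.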