diff --git a/llama.cpp b/llama.cpp index 85fa4ad76..22ff38012 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2681,10 +2681,10 @@ static void llm_load_hparams( // gpt-j n_rot = rotary_dim } - hparams.n_embd_head_k = hparams.n_embd / hparams.n_head_kv; + hparams.n_embd_head_k = hparams.n_embd / hparams.n_head; ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); - hparams.n_embd_head_v = hparams.n_embd / hparams.n_head_kv; + hparams.n_embd_head_v = hparams.n_embd / hparams.n_head; ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); // arch-specific KVs