llama : fix n_rot default (#8348)

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-07-07 14:59:02 +03:00 committed by Neo Zhang
parent 78706ed9a8
commit 155ec5bf82

View file

@ -4626,16 +4626,6 @@ static void llm_load_hparams(
// non-transformer models do not have attention heads // non-transformer models do not have attention heads
if (hparams.n_head() > 0) { if (hparams.n_head() > 0) {
// sanity check for n_rot (optional)
hparams.n_rot = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
}
}
// gpt-neox n_rot = rotary_pct * (n_embd / n_head) // gpt-neox n_rot = rotary_pct * (n_embd / n_head)
// gpt-j n_rot = rotary_dim // gpt-j n_rot = rotary_dim
@ -4644,6 +4634,17 @@ static void llm_load_hparams(
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head(); hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
// sanity check for n_rot (optional)
hparams.n_rot = hparams.n_embd_head_k;
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
if (hparams.n_rot != hparams.n_embd_head_k) {
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
}
}
} else { } else {
hparams.n_rot = 0; hparams.n_rot = 0;
hparams.n_embd_head_k = 0; hparams.n_embd_head_k = 0;
@ -11491,7 +11492,7 @@ struct llm_build_context {
Qcur = ggml_rope_ext( Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr, ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
@ -11500,7 +11501,7 @@ struct llm_build_context {
Kcur = ggml_rope_ext( Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr, ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
@ -11604,7 +11605,7 @@ struct llm_build_context {
Qcur = ggml_rope_ext( Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr, ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
@ -11613,7 +11614,7 @@ struct llm_build_context {
Kcur = ggml_rope_ext( Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr, ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);