parent
d39130a398
commit
b5040086d4
1 changed files with 15 additions and 14 deletions
|
@ -4625,16 +4625,6 @@ static void llm_load_hparams(
|
||||||
|
|
||||||
// non-transformer models do not have attention heads
|
// non-transformer models do not have attention heads
|
||||||
if (hparams.n_head() > 0) {
|
if (hparams.n_head() > 0) {
|
||||||
// sanity check for n_rot (optional)
|
|
||||||
hparams.n_rot = hparams.n_embd / hparams.n_head();
|
|
||||||
|
|
||||||
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
|
||||||
|
|
||||||
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
|
|
||||||
if (hparams.n_rot != hparams.n_embd / hparams.n_head()) {
|
|
||||||
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd / hparams.n_head()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
|
// gpt-neox n_rot = rotary_pct * (n_embd / n_head)
|
||||||
// gpt-j n_rot = rotary_dim
|
// gpt-j n_rot = rotary_dim
|
||||||
|
|
||||||
|
@ -4643,6 +4633,17 @@ static void llm_load_hparams(
|
||||||
|
|
||||||
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
|
hparams.n_embd_head_v = hparams.n_embd / hparams.n_head();
|
||||||
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
|
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false);
|
||||||
|
|
||||||
|
// sanity check for n_rot (optional)
|
||||||
|
hparams.n_rot = hparams.n_embd_head_k;
|
||||||
|
|
||||||
|
ml.get_key(LLM_KV_ROPE_DIMENSION_COUNT, hparams.n_rot, false);
|
||||||
|
|
||||||
|
if (model.arch == LLM_ARCH_LLAMA || model.arch == LLM_ARCH_FALCON) {
|
||||||
|
if (hparams.n_rot != hparams.n_embd_head_k) {
|
||||||
|
throw std::runtime_error(format("invalid n_rot: %u, expected %u", hparams.n_rot, hparams.n_embd_head_k));
|
||||||
|
}
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
hparams.n_rot = 0;
|
hparams.n_rot = 0;
|
||||||
hparams.n_embd_head_k = 0;
|
hparams.n_embd_head_k = 0;
|
||||||
|
@ -11490,7 +11491,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
@ -11499,7 +11500,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
@ -11603,7 +11604,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
Qcur = ggml_rope_ext(
|
Qcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Qcur, "Qcur", il);
|
cb(Qcur, "Qcur", il);
|
||||||
|
|
||||||
|
@ -11612,7 +11613,7 @@ struct llm_build_context {
|
||||||
|
|
||||||
Kcur = ggml_rope_ext(
|
Kcur = ggml_rope_ext(
|
||||||
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
|
||||||
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
|
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||||
cb(Kcur, "Kcur", il);
|
cb(Kcur, "Kcur", il);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue