ggml : refactor rope norm/neox (#7634)

* ggml : unify rope norm/neox (CPU)

* ggml : fix compile warning

* ggml : remove GLM rope mode

ggml-ci

* metal : better rope implementation

ggml-ci

* cuda : better rope implementation

ggml-ci

* naming : n_orig_ctx -> n_ctx_orig

ggml-ci

* dev : add reminders to update backends

ggml-ci

* vulkan : fix ggml_rope_ext() usage

* cuda : fix array size + indents

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-06-05 11:29:20 +03:00 committed by GitHub
parent 9973e81c5c
commit 2b3389677a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
19 changed files with 485 additions and 732 deletions

124
llama.cpp
View file

@ -1848,7 +1848,7 @@ struct llama_hparams {
float rope_attn_factor = 1.0f;
float rope_freq_base_train;
float rope_freq_scale_train;
uint32_t n_yarn_orig_ctx;
uint32_t n_ctx_orig_yarn;
float rope_yarn_log_mul;
// for State Space Models
@ -1890,7 +1890,7 @@ struct llama_hparams {
if (this->n_expert_shared != other.n_expert_shared) return true;
if (this->rope_finetuned != other.rope_finetuned) return true;
if (this->n_yarn_orig_ctx != other.n_yarn_orig_ctx) return true;
if (this->n_ctx_orig_yarn != other.n_ctx_orig_yarn) return true;
if (this->ssm_d_conv != other.ssm_d_conv) return true;
if (this->ssm_d_inner != other.ssm_d_inner) return true;
@ -1949,7 +1949,7 @@ struct llama_cparams {
float rope_freq_base;
float rope_freq_scale;
uint32_t n_yarn_orig_ctx;
uint32_t n_ctx_orig_yarn;
// These hyperparameters are not exposed in GGUF, because all
// existing YaRN models use the same values for them.
float yarn_ext_factor;
@ -4005,8 +4005,8 @@ static void llm_load_hparams(
ml.get_key(LLM_KV_ROPE_SCALING_FINETUNED, rope_finetuned, false);
hparams.rope_finetuned = rope_finetuned;
hparams.n_yarn_orig_ctx = hparams.n_ctx_train;
ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_yarn_orig_ctx, false);
hparams.n_ctx_orig_yarn = hparams.n_ctx_train;
ml.get_key(LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, hparams.n_ctx_orig_yarn, false);
// rope_freq_base (optional)
hparams.rope_freq_base_train = 10000.0f;
@ -4968,7 +4968,7 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type);
LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train);
LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train);
LLAMA_LOG_INFO("%s: n_yarn_orig_ctx = %u\n", __func__, hparams.n_yarn_orig_ctx);
LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn);
LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown");
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
@ -7134,7 +7134,7 @@ struct llm_build_context {
const int32_t n_kv; // size of KV cache to consider (n_kv <= kv_self.size)
const int32_t n_outputs;
const int32_t kv_head; // index of where we store new KV data in the cache
const int32_t n_orig_ctx;
const int32_t n_ctx_orig;
const bool flash_attn;
@ -7183,7 +7183,7 @@ struct llm_build_context {
n_kv (worst_case ? kv_self.size : kv_self.n),
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
n_orig_ctx (cparams.n_yarn_orig_ctx),
n_ctx_orig (cparams.n_ctx_orig_yarn),
flash_attn (cparams.flash_attn),
pooling_type (cparams.pooling_type),
rope_type (hparams.rope_type),
@ -7241,7 +7241,7 @@ struct llm_build_context {
ggml_row_size(kv_self.k_l[il]->type, n_embd_head_k),
ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa),
0),
lctx.inp_K_shift, rope_factors, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
lctx.inp_K_shift, rope_factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il);
@ -7350,7 +7350,7 @@ struct llm_build_context {
// choose long/short freq factors based on the context size
const auto n_ctx_pre_seq = cparams.n_ctx / cparams.n_seq_max;
if (n_ctx_pre_seq > hparams.n_yarn_orig_ctx) {
if (n_ctx_pre_seq > hparams.n_ctx_orig_yarn) {
return model.layers[il].rope_long;
}
@ -7466,14 +7466,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -7597,12 +7597,12 @@ struct llm_build_context {
case MODEL_7B:
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
break;
@ -7709,14 +7709,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -7829,13 +7829,13 @@ struct llm_build_context {
// using mode = 2 for neox mode
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -7953,14 +7953,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -8106,14 +8106,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -8460,14 +8460,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -8900,14 +8900,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -9019,13 +9019,13 @@ struct llm_build_context {
// using mode = 2 for neox mode
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -9131,14 +9131,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -9245,14 +9245,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -9397,7 +9397,7 @@ struct llm_build_context {
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
@ -9408,7 +9408,7 @@ struct llm_build_context {
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, 0, n_orig_ctx,
ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -9519,7 +9519,7 @@ struct llm_build_context {
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
Qcur = ggml_rope_ext(
ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
ctx0, Qcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
@ -9528,7 +9528,7 @@ struct llm_build_context {
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, 0, n_orig_ctx,
ctx0, Kcur, inp_pos, rope_factors, n_rot, rope_type, n_ctx_orig,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -9636,13 +9636,13 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_rot, n_head, n_tokens), inp_pos, nullptr,
n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_rot, n_head_kv, n_tokens), inp_pos, nullptr,
n_embd_head, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_embd_head, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
@ -9844,14 +9844,14 @@ struct llm_build_context {
struct ggml_tensor * Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, tmpq, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, tmpk, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -9960,14 +9960,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -10077,14 +10077,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -10207,14 +10207,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -10327,7 +10327,7 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head_k, n_head, n_tokens), inp_pos, nullptr,
n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Qcur, "Qcur", il);
@ -10336,7 +10336,7 @@ struct llm_build_context {
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head_k, n_head_kv, n_tokens), inp_pos, nullptr,
n_embd_head_k, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_embd_head_k, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow);
cb(Kcur, "Kcur", il);
@ -10447,14 +10447,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -10737,14 +10737,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -10868,14 +10868,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -10982,14 +10982,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -11117,14 +11117,14 @@ struct llm_build_context {
Qcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Qcur, "Qcur", il);
Kcur = ggml_rope_ext(
ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow
);
cb(Kcur, "Kcur", il);
@ -11334,7 +11334,7 @@ struct llm_build_context {
q_pe = ggml_cont(ctx0, q_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
q_pe = ggml_rope_ext(
ctx0, q_pe, inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor_scaled, beta_fast, beta_slow
);
cb(q_pe, "q_pe", il);
@ -11343,7 +11343,7 @@ struct llm_build_context {
k_pe = ggml_cont(ctx0, k_pe); // TODO: the CUDA backend does not support non-contiguous RoPE
k_pe = ggml_rope_ext(
ctx0, k_pe, inp_pos, nullptr,
n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
ext_factor, attn_factor_scaled, beta_fast, beta_slow
);
cb(k_pe, "k_pe", il);
@ -16067,8 +16067,8 @@ struct llama_context * llama_new_context_with_model(
cparams.n_ubatch = std::min(cparams.n_batch, params.n_ubatch == 0 ? params.n_batch : params.n_ubatch);
cparams.n_yarn_orig_ctx = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
hparams.n_yarn_orig_ctx != 0 ? hparams.n_yarn_orig_ctx :
cparams.n_ctx_orig_yarn = params.yarn_orig_ctx != 0 ? params.yarn_orig_ctx :
hparams.n_ctx_orig_yarn != 0 ? hparams.n_ctx_orig_yarn :
hparams.n_ctx_train;
cparams.cb_eval = params.cb_eval;