llama : fix llm_build_k_shift to use correct n_rot

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-01-12 10:46:39 +02:00
parent 7edefbd79c
commit ff0899c9b3
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 12 additions and 11 deletions

View file

@ -1055,6 +1055,9 @@ struct llama_model_params llama_model_params_from_gpt_params(const gpt_params &
}
static ggml_type kv_cache_type_from_str(const std::string & s) {
if (s == "f32") {
return GGML_TYPE_F32;
}
if (s == "f16") {
return GGML_TYPE_F16;
}

View file

@ -4104,7 +4104,6 @@ static void llm_build_k_shift(
struct ggml_cgraph * graph,
llm_rope_type type,
int64_t n_ctx,
int n_rot,
float freq_base,
float freq_scale,
const llm_build_cb & cb) {
@ -4112,14 +4111,13 @@ static void llm_build_k_shift(
const int64_t n_head_kv = hparams.n_head_kv;
const int64_t n_embd_head_k = hparams.n_embd_head_k;
const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
const int32_t n_rot = hparams.n_rot;
const int32_t n_orig_ctx = cparams.n_yarn_orig_ctx;
const float ext_factor = cparams.yarn_ext_factor;
const float attn_factor = cparams.yarn_attn_factor;
const float beta_fast = cparams.yarn_beta_fast;
const float beta_slow = cparams.yarn_beta_slow;
GGML_ASSERT(n_embd_head_k % n_rot == 0);
struct ggml_tensor * K_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_ctx);
cb(K_shift, "K_shift", -1);
@ -4523,7 +4521,7 @@ struct llm_build_context {
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@ -4708,7 +4706,7 @@ struct llm_build_context {
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@ -4829,7 +4827,7 @@ struct llm_build_context {
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@ -5052,7 +5050,7 @@ struct llm_build_context {
cb(KQ_mask, "KQ_mask", -1);
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@ -5548,7 +5546,7 @@ struct llm_build_context {
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb);
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@ -5661,7 +5659,7 @@ struct llm_build_context {
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@ -5778,7 +5776,7 @@ struct llm_build_context {
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {
@ -5891,7 +5889,7 @@ struct llm_build_context {
// shift the entire K-cache if needed
if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, n_embd_head, freq_base, freq_scale, cb);
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
}
for (int il = 0; il < n_layer; ++il) {