add n_dims parameter to llm_build_k_shift, default to n_rot via overload

This commit is contained in:
slaren 2023-11-22 00:45:43 +01:00
parent 4a3469f20e
commit 84adb5412c

View file

@ -3460,6 +3460,7 @@ static void llm_build_k_shift(
llm_rope_type type, llm_rope_type type,
int64_t n_ctx, int64_t n_ctx,
int64_t n_rot, int64_t n_rot,
int n_dims,
float freq_base, float freq_base,
float freq_scale, float freq_scale,
const llm_build_cb & cb) { const llm_build_cb & cb) {
@ -3495,13 +3496,28 @@ static void llm_build_k_shift(
ggml_element_size(kv.k)*n_embd_head, ggml_element_size(kv.k)*n_embd_head,
ggml_element_size(kv.k)*n_embd_gqa, ggml_element_size(kv.k)*n_embd_gqa,
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il), ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale, K_shift, n_dims, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
ext_factor, attn_factor, beta_fast, beta_slow); ext_factor, attn_factor, beta_fast, beta_slow);
cb(tmp, "K_shifted", il); cb(tmp, "K_shifted", il);
ggml_build_forward_expand(graph, tmp); ggml_build_forward_expand(graph, tmp);
} }
} }
// Convenience overload of llm_build_k_shift for callers that do not supply a
// separate n_dims value: every argument is forwarded unchanged to the
// n_dims-taking overload, and n_rot is passed a second time as n_dims.
static void llm_build_k_shift(
        struct ggml_context * ctx,
        const llama_hparams & hparams,
        const llama_cparams & cparams,
        const llama_kv_cache & kv,
        struct ggml_cgraph * graph,
        llm_rope_type type,
        int64_t n_ctx,
        int64_t n_rot,
        float freq_base,
        float freq_scale,
        const llm_build_cb & cb) {
    // the full overload declares n_dims as int, so the int64_t -> int
    // conversion of n_rot is spelled out here instead of happening implicitly
    const int n_dims = (int) n_rot;

    llm_build_k_shift(
            ctx, hparams, cparams, kv, graph,
            type, n_ctx, n_rot, n_dims,
            freq_base, freq_scale, cb);
}
static void llm_build_kv_store( static void llm_build_kv_store(
struct ggml_context * ctx, struct ggml_context * ctx,
const llama_hparams & hparams, const llama_hparams & hparams,
@ -4798,7 +4814,7 @@ struct llm_build_context {
// shift the entire K-cache if needed // shift the entire K-cache if needed
if (do_rope_shift) { if (do_rope_shift) {
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb); llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, hparams.n_rot, freq_base, freq_scale, cb);
} }
for (int il = 0; il < n_layer; ++il) { for (int il = 0; il < n_layer; ++il) {