llama : fix llm_build_k_shift args
This commit is contained in:
parent
84adb5412c
commit
5ed3e1a8f2
1 changed file with 4 additions and 20 deletions
24
llama.cpp
24
llama.cpp
|
@ -3459,8 +3459,7 @@ static void llm_build_k_shift(
|
|||
struct ggml_cgraph * graph,
|
||||
llm_rope_type type,
|
||||
int64_t n_ctx,
|
||||
int64_t n_rot,
|
||||
int n_dims,
|
||||
int n_rot,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
const llm_build_cb & cb) {
|
||||
|
@ -3492,32 +3491,17 @@ static void llm_build_k_shift(
|
|||
// we rotate only the first n_rot dimensions
|
||||
ggml_rope_custom_inplace(ctx,
|
||||
ggml_view_3d(ctx, kv.k,
|
||||
n_rot, n_head_kv, n_ctx,
|
||||
n_embd_head, n_head_kv, n_ctx,
|
||||
ggml_element_size(kv.k)*n_embd_head,
|
||||
ggml_element_size(kv.k)*n_embd_gqa,
|
||||
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
|
||||
K_shift, n_dims, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
K_shift, n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow);
|
||||
cb(tmp, "K_shifted", il);
|
||||
ggml_build_forward_expand(graph, tmp);
|
||||
}
|
||||
}
|
||||
|
||||
static void llm_build_k_shift(
|
||||
struct ggml_context * ctx,
|
||||
const llama_hparams & hparams,
|
||||
const llama_cparams & cparams,
|
||||
const llama_kv_cache & kv,
|
||||
struct ggml_cgraph * graph,
|
||||
llm_rope_type type,
|
||||
int64_t n_ctx,
|
||||
int64_t n_rot,
|
||||
float freq_base,
|
||||
float freq_scale,
|
||||
const llm_build_cb & cb) {
|
||||
llm_build_k_shift(ctx, hparams, cparams, kv, graph, type, n_ctx, n_rot, n_rot, freq_base, freq_scale, cb);
|
||||
}
|
||||
|
||||
static void llm_build_kv_store(
|
||||
struct ggml_context * ctx,
|
||||
const llama_hparams & hparams,
|
||||
|
@ -4814,7 +4798,7 @@ struct llm_build_context {
|
|||
|
||||
// shift the entire K-cache if needed
|
||||
if (do_rope_shift) {
|
||||
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, hparams.n_rot, freq_base, freq_scale, cb);
|
||||
llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, hparams.n_rot, freq_base, freq_scale, cb);
|
||||
}
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue