mamba : move state_seq and state_mask views outside layer loop
A few tensors were also missing `struct` in front of `ggml_tensor`.
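Why the hoist helps: state_mask and state_seq are views of whole-batch inputs
(lctx.inp_s_mask and lctx.inp_s_seq) and do not depend on the layer index, yet
every ggml_view_2d() call appends a node to the compute graph, so building them
inside the loop cost two view nodes per layer where two in total suffice. A
minimal sketch of the pattern follows; the node counter and make_view() are
hypothetical stand-ins for the real ggml graph, not llama.cpp code:

// Sketch: counting graph nodes created by a loop-invariant view, before and
// after hoisting it out of the layer loop. make_view() is a hypothetical
// stand-in for ggml_view_2d(); each call grows the "graph" by one node.
#include <cstdio>

static int n_nodes = 0;

struct node { int id; };

static node make_view() { return node{ n_nodes++ }; }

int main() {
    const int n_layer = 32;

    // before this commit: one identical view built per layer
    n_nodes = 0;
    for (int il = 0; il < n_layer; ++il) {
        node state_mask = make_view();
        (void) state_mask;
    }
    std::printf("per-layer views: %d nodes\n", n_nodes);

    // after this commit: one view, reused by every layer
    n_nodes = 0;
    node state_mask = make_view();
    for (int il = 0; il < n_layer; ++il) {
        (void) state_mask; // each layer reuses the same node
    }
    std::printf("hoisted view:    %d nodes\n", n_nodes);
    return 0;
}

The struct fixes are purely cosmetic in llama.cpp itself: ggml declares
struct ggml_tensor without a typedef, so the tag is mandatory in C, while C++
accepts both spellings, which is why the omissions compiled unnoticed;
spelling the tag out keeps the declarations uniform with the rest of the
codebase.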
parent 3e5685f7ea
commit 39579d3ceb
1 changed file with 9 additions and 7 deletions
llama.cpp | 16 +++++++++-------
--- a/llama.cpp
+++ b/llama.cpp
@@ -5540,9 +5540,11 @@ struct llm_build_context {
     struct ggml_cgraph * build_s_copy() {
         struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
 
+        GGML_ASSERT(kv_self.recurrent);
+
         for (int il = 0; il < n_layer; ++il) {
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
             conv_states = ggml_get_rows(ctx0, conv_states, lctx.inp_s_copy);
             ssm_states = ggml_get_rows(ctx0, ssm_states, lctx.inp_s_copy);
@@ -8171,14 +8173,16 @@ struct llm_build_context {
         inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, lctx.inp_tokens, lctx.inp_embd, cb);
         cb(inpL, "inp_embd", -1);
 
+        struct ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
+        struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
+
         for (int il = 0; il < n_layer; ++il) {
             // (ab)using the KV cache to store the states
-            ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
-            ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
+            struct ggml_tensor * conv_states = ggml_reshape_2d(ctx0, kv_self.k_l[il], hparams.n_embd_k_s(), kv_self.size);
+            struct ggml_tensor * ssm_states = ggml_reshape_2d(ctx0, kv_self.v_l[il], hparams.n_embd_v_s(), kv_self.size);
 
             // clear states of sequences which are starting at the beginning of this batch
             {
-                ggml_tensor * state_mask = ggml_view_2d(ctx0, lctx.inp_s_mask, 1, n_kv, lctx.inp_s_mask->nb[0], 0);
                 conv_states = ggml_mul(ctx0,
                     ggml_view_2d(ctx0, conv_states, conv_states->ne[0], n_kv, conv_states->nb[1], kv_head*conv_states->nb[1]),
                     state_mask);
@@ -8203,8 +8207,6 @@ struct llm_build_context {
                 struct ggml_tensor * x = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], 0);
                 struct ggml_tensor * z = ggml_view_2d(ctx0, xz, d_inner, xz->ne[1], xz->nb[1], ggml_element_size(xz)*d_inner);
 
-                struct ggml_tensor * state_seq = ggml_view_2d(ctx0, lctx.inp_s_seq, n_kv, n_tokens, n_kv*ggml_element_size(lctx.inp_s_seq), 0);
-
                 // conv
                 {
                     // Custom operator which is needed only to ease simultaneous sequence processing.