fix: A number of places where hybrid needs to be handled

Still not fully working, but worth committing these:

* per-layer n_embd_[kv]_s (probably a no-op since first layer is ssm)
* fix setting n_kv_hybrid when not worst_case
* Use the right n_kv for build_inp_s_copy when hybrid
* Use the right n_kv for recurrent section of llama_set_inputs
* Use the right logic to determine batch splitting for hybrid

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart 2024-12-10 15:34:53 -07:00
parent 4543ed5640
commit 204e78fba1

View file

@ -10460,11 +10460,11 @@ static struct ggml_tensor * llm_build_mamba2(
// (ab)using the KV cache to store the states
struct ggml_tensor * conv = llm_build_rs(ctx,
graph, conv_states_all, state_copy, rs_zero,
hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
hparams.n_embd_k_s(il), kv.size, kv_head, n_kv, n_seqs);
conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
struct ggml_tensor * ssm = llm_build_rs(ctx,
graph, ssm_states_all, state_copy, rs_zero,
hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true);
hparams.n_embd_v_s(il), kv.size, kv_head, n_kv, n_seqs, true);
ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size);
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
@ -10808,7 +10808,7 @@ struct llm_build_context {
norm_rms_eps (hparams.f_norm_rms_eps),
n_tokens (ubatch.n_tokens),
n_kv (worst_case ? kv_self.size : kv_self.n),
n_kv_hybrid (worst_case ? kv_hybrid.size : kv_self.n),
n_kv_hybrid (worst_case ? kv_hybrid.size : kv_hybrid.n),
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
@ -11036,8 +11036,8 @@ struct llm_build_context {
return lctx.inp_cls;
}
struct ggml_tensor * build_inp_s_copy() {
lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
struct ggml_tensor * build_inp_s_copy(bool hybrid = false) {
lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, hybrid ? n_kv_hybrid : n_kv);
cb(lctx.inp_s_copy, "inp_s_copy", -1);
ggml_set_input(lctx.inp_s_copy);
return lctx.inp_s_copy;
@ -14686,7 +14686,7 @@ struct llm_build_context {
// {n_embd, n_tokens}
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
struct ggml_tensor * state_copy = build_inp_s_copy();
struct ggml_tensor * state_copy = build_inp_s_copy(/* hybrid */true);
const int64_t n_embd_head = hparams.n_embd_head_v;
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@ -14710,7 +14710,8 @@ struct llm_build_context {
if (hparams.recurrent_layer(il)) {
// ssm layer
cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, true);
rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il,
/* hybrid */ true);
cb(cur, "mamba_out", il);
} else {
// attention layer //
@ -17813,8 +17814,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)
}
}
if (kv_self.recurrent) {
const int64_t n_kv = kv_self.n;
const bool hybrid = llama_model_is_hybrid(&lctx.model);
auto& kv_hybrid = lctx.kv_hybrid;
if (kv_self.recurrent || (hybrid && kv_hybrid.recurrent)) {
auto& kv_recurrent = hybrid ? kv_hybrid : lctx.kv_self;
const int64_t n_kv = kv_recurrent.n;
if (lctx.inp_s_copy) {
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
@ -17822,14 +17826,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
for (uint32_t i = 0; i < n_kv; ++i) {
const uint32_t cell_id = i + kv_self.head;
llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
const uint32_t cell_id = i + kv_recurrent.head;
llama_kv_cell & kv_cell = kv_recurrent.cells[cell_id];
if (kv_cell.src < 0) {
GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source
kv_cell.src = kv_self.rs_z;
GGML_ASSERT(kv_recurrent.rs_z >= 0); // Need a valid zero-ed cell as a source
kv_cell.src = kv_recurrent.rs_z;
}
if ((uint32_t) kv_cell.src >= kv_self.size) {
if ((uint32_t) kv_cell.src >= kv_recurrent.size) {
// ignore out-of-bound sources
kv_cell.src = cell_id;
}
@ -18135,7 +18139,7 @@ static int llama_decode_internal(
}
lctx.sbatch.from_batch(batch, n_embd,
/* simple_split */ !kv_self.recurrent,
/* simple_split */ !(kv_self.recurrent || (hybrid && kv_hybrid.recurrent)),
/* logits_all */ n_outputs == n_tokens_all);
// reserve output buffer
@ -18146,7 +18150,7 @@ static int llama_decode_internal(
while (lctx.sbatch.n_tokens > 0) {
llama_ubatch ubatch;
if (kv_self.recurrent) {
if (kv_self.recurrent || (hybrid && kv_hybrid.recurrent)) {
if (embd_pooled) {
// Pooled embeddings cannot be split across ubatches (yet)
ubatch = lctx.sbatch.split_seq(n_ubatch);