fix: A number of places where hybrid needs to be handled
Still not fully working, but worth committing these: * per-layer n_embd_[kv]_s (probably a no-op since first layer is ssm) * fix setting n_kv_hybrid when not worst_case * Use the right n_kv for build_inp_s_copy when hybrid * Use the right n_kv for recurrent section of llama_set_inputs * Use the right logic to determine batch splitting for hybrid Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent
4543ed5640
commit
204e78fba1
1 changed files with 20 additions and 16 deletions
|
@ -10460,11 +10460,11 @@ static struct ggml_tensor * llm_build_mamba2(
|
|||
// (ab)using the KV cache to store the states
|
||||
struct ggml_tensor * conv = llm_build_rs(ctx,
|
||||
graph, conv_states_all, state_copy, rs_zero,
|
||||
hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
|
||||
hparams.n_embd_k_s(il), kv.size, kv_head, n_kv, n_seqs);
|
||||
conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
|
||||
struct ggml_tensor * ssm = llm_build_rs(ctx,
|
||||
graph, ssm_states_all, state_copy, rs_zero,
|
||||
hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true);
|
||||
hparams.n_embd_v_s(il), kv.size, kv_head, n_kv, n_seqs, true);
|
||||
ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size);
|
||||
|
||||
// {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
|
||||
|
@ -10808,7 +10808,7 @@ struct llm_build_context {
|
|||
norm_rms_eps (hparams.f_norm_rms_eps),
|
||||
n_tokens (ubatch.n_tokens),
|
||||
n_kv (worst_case ? kv_self.size : kv_self.n),
|
||||
n_kv_hybrid (worst_case ? kv_hybrid.size : kv_self.n),
|
||||
n_kv_hybrid (worst_case ? kv_hybrid.size : kv_hybrid.n),
|
||||
n_outputs (worst_case ? n_tokens : lctx.n_outputs),
|
||||
n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
|
||||
kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
|
||||
|
@ -11036,8 +11036,8 @@ struct llm_build_context {
|
|||
return lctx.inp_cls;
|
||||
}
|
||||
|
||||
struct ggml_tensor * build_inp_s_copy() {
|
||||
lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
|
||||
struct ggml_tensor * build_inp_s_copy(bool hybrid = false) {
|
||||
lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, hybrid ? n_kv_hybrid : n_kv);
|
||||
cb(lctx.inp_s_copy, "inp_s_copy", -1);
|
||||
ggml_set_input(lctx.inp_s_copy);
|
||||
return lctx.inp_s_copy;
|
||||
|
@ -14686,7 +14686,7 @@ struct llm_build_context {
|
|||
// {n_embd, n_tokens}
|
||||
inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);
|
||||
|
||||
struct ggml_tensor * state_copy = build_inp_s_copy();
|
||||
struct ggml_tensor * state_copy = build_inp_s_copy(/* hybrid */true);
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
@ -14710,7 +14710,8 @@ struct llm_build_context {
|
|||
if (hparams.recurrent_layer(il)) {
|
||||
// ssm layer
|
||||
cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
|
||||
rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, true);
|
||||
rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il,
|
||||
/* hybrid */ true);
|
||||
cb(cur, "mamba_out", il);
|
||||
} else {
|
||||
// attention layer //
|
||||
|
@ -17813,8 +17814,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)
|
|||
}
|
||||
}
|
||||
|
||||
if (kv_self.recurrent) {
|
||||
const int64_t n_kv = kv_self.n;
|
||||
const bool hybrid = llama_model_is_hybrid(&lctx.model);
|
||||
auto& kv_hybrid = lctx.kv_hybrid;
|
||||
if (kv_self.recurrent || (hybrid && kv_hybrid.recurrent)) {
|
||||
auto& kv_recurrent = hybrid ? kv_hybrid : lctx.kv_self;
|
||||
const int64_t n_kv = kv_recurrent.n;
|
||||
|
||||
if (lctx.inp_s_copy) {
|
||||
GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
|
||||
|
@ -17822,14 +17826,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)
|
|||
|
||||
// assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
const uint32_t cell_id = i + kv_self.head;
|
||||
llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
|
||||
const uint32_t cell_id = i + kv_recurrent.head;
|
||||
llama_kv_cell & kv_cell = kv_recurrent.cells[cell_id];
|
||||
|
||||
if (kv_cell.src < 0) {
|
||||
GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source
|
||||
kv_cell.src = kv_self.rs_z;
|
||||
GGML_ASSERT(kv_recurrent.rs_z >= 0); // Need a valid zero-ed cell as a source
|
||||
kv_cell.src = kv_recurrent.rs_z;
|
||||
}
|
||||
if ((uint32_t) kv_cell.src >= kv_self.size) {
|
||||
if ((uint32_t) kv_cell.src >= kv_recurrent.size) {
|
||||
// ignore out-of-bound sources
|
||||
kv_cell.src = cell_id;
|
||||
}
|
||||
|
@ -18135,7 +18139,7 @@ static int llama_decode_internal(
|
|||
}
|
||||
|
||||
lctx.sbatch.from_batch(batch, n_embd,
|
||||
/* simple_split */ !kv_self.recurrent,
|
||||
/* simple_split */ !(kv_self.recurrent || (hybrid && kv_hybrid.recurrent)),
|
||||
/* logits_all */ n_outputs == n_tokens_all);
|
||||
|
||||
// reserve output buffer
|
||||
|
@ -18146,7 +18150,7 @@ static int llama_decode_internal(
|
|||
|
||||
while (lctx.sbatch.n_tokens > 0) {
|
||||
llama_ubatch ubatch;
|
||||
if (kv_self.recurrent) {
|
||||
if (kv_self.recurrent || (hybrid && kv_hybrid.recurrent)) {
|
||||
if (embd_pooled) {
|
||||
// Pooled embeddings cannot be split across ubatches (yet)
|
||||
ubatch = lctx.sbatch.split_seq(n_ubatch);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue