fix: A number of places where hybrid needs to be handled
Still not fully working, but worth committing these:

* per-layer n_embd_[kv]_s (probably a no-op since first layer is ssm)
* fix setting n_kv_hybrid when not worst_case
* Use the right n_kv for build_inp_s_copy when hybrid
* Use the right n_kv for recurrent section of llama_set_inputs
* Use the right logic to determine batch splitting for hybrid

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent 4543ed5640
commit 204e78fba1
1 changed file with 20 additions and 16 deletions
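The hunks below thread a second, recurrent-state cache (lctx.kv_hybrid) through places that previously assumed kv_self was the only cache. As a minimal, self-contained sketch of that selection idea (all type and function names here are simplified stand-ins, not the llama.cpp definitions):

// Sketch only: how a hybrid model might pick the cache that holds recurrent
// (SSM) state, separate from the attention KV cache. `kv_cache`, `context`,
// and `pick_recurrent_cache` are illustrative stand-ins, not llama.cpp types.
#include <cstdint>
#include <cstdio>

struct kv_cache {
    uint32_t size      = 0;     // total number of cells
    uint32_t n         = 0;     // cells in use for the current batch
    bool     recurrent = false; // true if this cache stores recurrent states
};

struct context {
    kv_cache kv_self;           // attention KV cache
    kv_cache kv_hybrid;         // separate recurrent-state cache for hybrid models
    bool     is_hybrid = false;
};

// recurrent layers read/write the hybrid cache when the model is hybrid,
// otherwise they fall back to the single self cache (pure SSM models)
static kv_cache & pick_recurrent_cache(context & ctx) {
    return ctx.is_hybrid ? ctx.kv_hybrid : ctx.kv_self;
}

int main() {
    context ctx;
    ctx.is_hybrid           = true;
    ctx.kv_hybrid.recurrent = true;
    ctx.kv_hybrid.size      = 8;
    ctx.kv_hybrid.n         = 2;

    const kv_cache & kv = pick_recurrent_cache(ctx);
    printf("recurrent cache: %u cells total, %u in use\n", kv.size, kv.n);
    return 0;
}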
@@ -10460,11 +10460,11 @@ static struct ggml_tensor * llm_build_mamba2(
     // (ab)using the KV cache to store the states
     struct ggml_tensor * conv = llm_build_rs(ctx,
             graph, conv_states_all, state_copy, rs_zero,
-            hparams.n_embd_k_s(), kv.size, kv_head, n_kv, n_seqs);
+            hparams.n_embd_k_s(il), kv.size, kv_head, n_kv, n_seqs);
     conv = ggml_reshape_3d(ctx, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs);
     struct ggml_tensor * ssm = llm_build_rs(ctx,
             graph, ssm_states_all, state_copy, rs_zero,
-            hparams.n_embd_v_s(), kv.size, kv_head, n_kv, n_seqs, true);
+            hparams.n_embd_v_s(il), kv.size, kv_head, n_kv, n_seqs, true);
     ssm = ggml_reshape_4d(ctx, ssm, d_state, head_dim, n_head, kv.size);

     // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs}
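The hunk above swaps the global hparams.n_embd_k_s()/n_embd_v_s() for the per-layer overloads that take il. A hedged sketch of why a per-layer lookup matters once attention and SSM layers are mixed; the toy_hparams fields and size formulas here are illustrative, not the real llama.cpp hparams:

// Sketch: per-layer recurrent state sizes for a hybrid stack. Attention layers
// carry no conv/SSM state, so a per-layer lookup can return 0 for them.
#include <cstdint>
#include <vector>

struct toy_hparams {
    std::vector<bool> recurrent_layer; // one flag per layer
    uint32_t conv_state_size = 0;      // e.g. (d_conv - 1) * (d_inner + 2*n_group*d_state)
    uint32_t ssm_state_size  = 0;      // e.g. d_state * d_inner

    // analogous in spirit to hparams.n_embd_k_s(il): conv state size for layer il
    uint32_t n_embd_k_s(uint32_t il) const {
        return recurrent_layer[il] ? conv_state_size : 0;
    }
    // analogous in spirit to hparams.n_embd_v_s(il): SSM state size for layer il
    uint32_t n_embd_v_s(uint32_t il) const {
        return recurrent_layer[il] ? ssm_state_size : 0;
    }
};

As the commit message notes, this is probably a no-op here since the first layer is an SSM layer, but it keeps the per-layer path correct for other layer orderings.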
@@ -10808,7 +10808,7 @@ struct llm_build_context {
         norm_rms_eps (hparams.f_norm_rms_eps),
         n_tokens (ubatch.n_tokens),
         n_kv (worst_case ? kv_self.size : kv_self.n),
-        n_kv_hybrid (worst_case ? kv_hybrid.size : kv_self.n),
+        n_kv_hybrid (worst_case ? kv_hybrid.size : kv_hybrid.n),
         n_outputs (worst_case ? n_tokens : lctx.n_outputs),
         n_outputs_enc (worst_case ? n_tokens : lctx.embd_enc.size() / hparams.n_embd),
         kv_head (worst_case ? (kv_self.recurrent ? 0 : kv_self.size - n_tokens) : kv_self.head),
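The one-line n_kv_hybrid fix above is about which cache feeds the worst_case/actual split: the worst-case graph reserves for the whole cache, the normal path only for the cells in use, and for the hybrid cache both must come from kv_hybrid. A tiny sketch of that decision, with cache_view as an illustrative stand-in:

// Sketch: choosing n_kv when building the graph. `cache_view` stands in for
// the real cache type; the point is that size and n must come from the same cache.
#include <cstdint>

struct cache_view {
    uint32_t size; // total cells (worst case)
    uint32_t n;    // cells in use for this batch
};

static uint32_t compute_n_kv(const cache_view & kv, bool worst_case) {
    // before the fix, the hybrid path paired kv_hybrid.size with kv_self.n
    return worst_case ? kv.size : kv.n;
}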
@@ -11036,8 +11036,8 @@ struct llm_build_context {
         return lctx.inp_cls;
     }

-    struct ggml_tensor * build_inp_s_copy() {
-        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_kv);
+    struct ggml_tensor * build_inp_s_copy(bool hybrid = false) {
+        lctx.inp_s_copy = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, hybrid ? n_kv_hybrid : n_kv);
         cb(lctx.inp_s_copy, "inp_s_copy", -1);
         ggml_set_input(lctx.inp_s_copy);
         return lctx.inp_s_copy;
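build_inp_s_copy above grows a defaulted hybrid flag so existing recurrent-only call sites keep their behaviour while hybrid graphs size the copy-index tensor from n_kv_hybrid. A sketch of the same defaulted-flag pattern with hypothetical names (make_state_copy_input is not a llama.cpp function):

// Sketch of the defaulted-flag pattern: old callers pass nothing, hybrid
// callers opt in. The vector stands in for the I32 copy-index tensor.
#include <cstdint>
#include <vector>

static std::vector<int32_t> make_state_copy_input(uint32_t n_kv,
                                                  uint32_t n_kv_hybrid,
                                                  bool hybrid = false) {
    // the copy-index input is as long as the cache that actually holds the states
    return std::vector<int32_t>(hybrid ? n_kv_hybrid : n_kv);
}

The next two hunks show the hybrid graph builder opting in with /* hybrid */true.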
@@ -14686,7 +14686,7 @@ struct llm_build_context {
         // {n_embd, n_tokens}
         inpL = llm_build_inp_embd(ctx0, lctx, hparams, ubatch, model.tok_embd, cb);

-        struct ggml_tensor * state_copy = build_inp_s_copy();
+        struct ggml_tensor * state_copy = build_inp_s_copy(/* hybrid */true);

         const int64_t n_embd_head = hparams.n_embd_head_v;
         GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
@@ -14710,7 +14710,8 @@ struct llm_build_context {
             if (hparams.recurrent_layer(il)) {
                 // ssm layer
                 cur = llm_build_mamba2(ctx0, lctx, ubatch, gf, cur, state_copy,
-                        rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il, true);
+                        rs_zero_hybrid, kv_head_hybrid, n_kv_hybrid, cb, il,
+                        /* hybrid */ true);
                 cb(cur, "mamba_out", il);
             } else {
                 // attention layer //
@@ -17813,8 +17814,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)
         }
     }

-    if (kv_self.recurrent) {
-        const int64_t n_kv = kv_self.n;
+    const bool hybrid = llama_model_is_hybrid(&lctx.model);
+    auto& kv_hybrid = lctx.kv_hybrid;
+    if (kv_self.recurrent || (hybrid && kv_hybrid.recurrent)) {
+        auto& kv_recurrent = hybrid ? kv_hybrid : lctx.kv_self;
+        const int64_t n_kv = kv_recurrent.n;

         if (lctx.inp_s_copy) {
             GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_s_copy->buffer));
@@ -17822,14 +17826,14 @@ static void llama_set_inputs(llama_context & lctx, const llama_ubatch & ubatch)

             // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n
             for (uint32_t i = 0; i < n_kv; ++i) {
-                const uint32_t cell_id = i + kv_self.head;
-                llama_kv_cell & kv_cell = lctx.kv_self.cells[cell_id];
+                const uint32_t cell_id = i + kv_recurrent.head;
+                llama_kv_cell & kv_cell = kv_recurrent.cells[cell_id];

                 if (kv_cell.src < 0) {
-                    GGML_ASSERT(kv_self.rs_z >= 0); // Need a valid zero-ed cell as a source
-                    kv_cell.src = kv_self.rs_z;
+                    GGML_ASSERT(kv_recurrent.rs_z >= 0); // Need a valid zero-ed cell as a source
+                    kv_cell.src = kv_recurrent.rs_z;
                 }
-                if ((uint32_t) kv_cell.src >= kv_self.size) {
+                if ((uint32_t) kv_cell.src >= kv_recurrent.size) {
                     // ignore out-of-bound sources
                     kv_cell.src = cell_id;
                 }
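The two llama_set_inputs hunks above bind a kv_recurrent reference to whichever cache holds the recurrent states, then run the existing copy-source fix-up loop against it unchanged. A self-contained sketch of that pattern with simplified types (toy_cache and toy_cell are not the llama.cpp structs):

// Sketch: pick the recurrent cache once, then repair copy sources against it.
#include <cstdint>
#include <vector>

struct toy_cell  { int32_t src = -1; };

struct toy_cache {
    bool     recurrent = false;
    uint32_t head = 0, size = 0, n = 0;
    int32_t  rs_z = -1;                  // index of a zero-ed cell, if any
    std::vector<toy_cell> cells;
};

static void fix_copy_sources(toy_cache & kv_self, toy_cache & kv_hybrid, bool hybrid) {
    toy_cache & kv = hybrid ? kv_hybrid : kv_self; // plays the role of kv_recurrent
    if (!kv.recurrent) {
        return;
    }
    for (uint32_t i = 0; i < kv.n; ++i) {
        toy_cell & cell = kv.cells[i + kv.head];
        if (cell.src < 0) {
            cell.src = kv.rs_z;                    // fall back to the zero-ed cell
        }
        if ((uint32_t) cell.src >= kv.size) {
            cell.src = (int32_t)(i + kv.head);     // ignore out-of-bound sources
        }
    }
}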
@@ -18135,7 +18139,7 @@ static int llama_decode_internal(
     }

     lctx.sbatch.from_batch(batch, n_embd,
-            /* simple_split */ !kv_self.recurrent,
+            /* simple_split */ !(kv_self.recurrent || (hybrid && kv_hybrid.recurrent)),
             /* logits_all */ n_outputs == n_tokens_all);

     // reserve output buffer
@@ -18146,7 +18150,7 @@ static int llama_decode_internal(

     while (lctx.sbatch.n_tokens > 0) {
         llama_ubatch ubatch;
-        if (kv_self.recurrent) {
+        if (kv_self.recurrent || (hybrid && kv_hybrid.recurrent)) {
             if (embd_pooled) {
                 // Pooled embeddings cannot be split across ubatches (yet)
                 ubatch = lctx.sbatch.split_seq(n_ubatch);
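The last two hunks make the batch-splitting decision aware of the hybrid cache: simple splitting is only safe when no recurrent state is involved, so a hybrid model whose kv_hybrid cache is recurrent has to split by sequence like a pure SSM model. A minimal sketch of that predicate, with boolean parameters standing in for the real cache and model state:

// Sketch: when is a simple (non-sequence-aware) batch split allowed?
#include <cstdio>

static bool use_simple_split(bool kv_self_recurrent, bool hybrid, bool kv_hybrid_recurrent) {
    // recurrent state needs equal-sequence ubatches, so simple splitting is
    // only safe when neither cache is recurrent
    return !(kv_self_recurrent || (hybrid && kv_hybrid_recurrent));
}

int main() {
    // hybrid attention+SSM model: the hybrid cache is recurrent -> no simple split
    printf("simple_split = %d\n",
           use_simple_split(/*kv_self*/ false, /*hybrid*/ true, /*kv_hybrid*/ true) ? 1 : 0);
    return 0;
}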