feat: Update the logic in llama_decode_internal for kv_hybrid cache

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart 2024-12-10 11:00:01 -07:00
parent 44bf431ab4
commit 4543ed5640

View file

@ -18104,6 +18104,11 @@ static int llama_decode_internal(
auto & kv_self = lctx.kv_self; auto & kv_self = lctx.kv_self;
llama_kv_slot_restorer kv_slot_restorer(kv_self); llama_kv_slot_restorer kv_slot_restorer(kv_self);
// Only used for hybrid-recurrent models (e.g. Bamba)
const bool hybrid = llama_model_is_hybrid(&model);
auto & kv_hybrid = lctx.kv_hybrid;
llama_kv_slot_restorer kv_slot_restorer_hybrid(kv_hybrid);
const int64_t n_embd = hparams.n_embd; const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab; const int64_t n_vocab = hparams.n_vocab;
@ -18192,7 +18197,15 @@ static int llama_decode_internal(
return 1; return 1;
} }
kv_slot_restorer.save(slot); kv_slot_restorer.save(slot);
if (hybrid) {
const auto slot_hybrid = llama_kv_cache_find_slot(kv_hybrid, ubatch);
if (!slot_hybrid) {
return 1;
}
kv_slot_restorer_hybrid.save(slot_hybrid);
}
// TODO: Update this clause for hybrid recurrent models
if (!kv_self.recurrent) { if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized // a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears // after enough generations, the benefit from this heuristic disappears
@ -18241,6 +18254,9 @@ static int llama_decode_internal(
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool); const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
if (compute_status != GGML_STATUS_SUCCESS) { if (compute_status != GGML_STATUS_SUCCESS) {
kv_slot_restorer.restore(kv_self); kv_slot_restorer.restore(kv_self);
if (hybrid) {
kv_slot_restorer_hybrid.restore(kv_hybrid);
}
switch (compute_status) { switch (compute_status) {
case GGML_STATUS_ABORTED: case GGML_STATUS_ABORTED:
return 2; return 2;
@ -18252,7 +18268,7 @@ static int llama_decode_internal(
} }
} }
// update the kv ring buffer // update the kv ring buffer(s)
{ {
kv_self.head += n_tokens; kv_self.head += n_tokens;
@ -18260,6 +18276,13 @@ static int llama_decode_internal(
if (kv_self.head >= kv_self.size) { if (kv_self.head >= kv_self.size) {
kv_self.head = 0; kv_self.head = 0;
} }
if (hybrid) {
kv_hybrid.head += n_tokens;
if (kv_hybrid.head >= kv_hybrid.size) {
kv_hybrid.head = 0;
}
}
} }
// plot the computation graph in dot format (for debugging purposes) // plot the computation graph in dot format (for debugging purposes)
@ -18366,7 +18389,7 @@ static int llama_decode_internal(
// wait for the computation to finish (automatically done when obtaining the model output) // wait for the computation to finish (automatically done when obtaining the model output)
//llama_synchronize(&lctx); //llama_synchronize(&lctx);
// decide if we need to defrag the kv cache // decide if we need to defrag the kv cache(s)
if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) { if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f; const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
@ -18376,6 +18399,13 @@ static int llama_decode_internal(
llama_kv_cache_defrag(kv_self); llama_kv_cache_defrag(kv_self);
} }
if (hybrid) {
const float fragmentation = kv_hybrid.n >= 128 ? 1.0f - float(kv_hybrid.used)/float(kv_hybrid.n) : 0.0f;
if (fragmentation > cparams.defrag_thold) {
llama_kv_cache_defrag(kv_hybrid);
}
}
} }
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to // Reset state for the next token before backend sync, to allow the CPU activities in the reset to