From 4543ed56402cb4e3e6f60aa655422e1a7cafd9e1 Mon Sep 17 00:00:00 2001
From: Gabe Goodhart
Date: Tue, 10 Dec 2024 11:00:01 -0700
Subject: [PATCH] feat: Update the logic in llama_decode_internal for kv_hybrid cache

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart
---
 src/llama.cpp | 34 ++++++++++++++++++++++++++++++++--
 1 file changed, 32 insertions(+), 2 deletions(-)

diff --git a/src/llama.cpp b/src/llama.cpp
index e24864614..c09471aaf 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -18104,6 +18104,11 @@ static int llama_decode_internal(
     auto & kv_self = lctx.kv_self;
     llama_kv_slot_restorer kv_slot_restorer(kv_self);
 
+    // Only used for hybrid-recurrent models (e.g. Bamba)
+    const bool hybrid = llama_model_is_hybrid(&model);
+    auto & kv_hybrid = lctx.kv_hybrid;
+    llama_kv_slot_restorer kv_slot_restorer_hybrid(kv_hybrid);
+
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
@@ -18192,7 +18197,15 @@ static int llama_decode_internal(
                 return 1;
             }
             kv_slot_restorer.save(slot);
+            if (hybrid) {
+                const auto slot_hybrid = llama_kv_cache_find_slot(kv_hybrid, ubatch);
+                if (!slot_hybrid) {
+                    return 1;
+                }
+                kv_slot_restorer_hybrid.save(slot_hybrid);
+            }
 
+            // TODO: Update this clause for hybrid recurrent models
             if (!kv_self.recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
@@ -18241,6 +18254,9 @@ static int llama_decode_internal(
         const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
         if (compute_status != GGML_STATUS_SUCCESS) {
             kv_slot_restorer.restore(kv_self);
+            if (hybrid) {
+                kv_slot_restorer_hybrid.restore(kv_hybrid);
+            }
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
@@ -18252,7 +18268,7 @@ static int llama_decode_internal(
             }
         }
 
-        // update the kv ring buffer
+        // update the kv ring buffer(s)
        {
             kv_self.head += n_tokens;
 
@@ -18260,6 +18276,13 @@ static int llama_decode_internal(
             if (kv_self.head >= kv_self.size) {
                 kv_self.head = 0;
             }
+
+            if (hybrid) {
+                kv_hybrid.head += n_tokens;
+                if (kv_hybrid.head >= kv_hybrid.size) {
+                    kv_hybrid.head = 0;
+                }
+            }
         }
 
         // plot the computation graph in dot format (for debugging purposes)
@@ -18366,7 +18389,7 @@ static int llama_decode_internal(
     // wait for the computation to finish (automatically done when obtaining the model output)
     //llama_synchronize(&lctx);
 
-    // decide if we need to defrag the kv cache
+    // decide if we need to defrag the kv cache(s)
     if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
         const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
 
@@ -18376,6 +18399,13 @@ static int llama_decode_internal(
             llama_kv_cache_defrag(kv_self);
         }
+
+        if (hybrid) {
+            const float fragmentation = kv_hybrid.n >= 128 ? 1.0f - float(kv_hybrid.used)/float(kv_hybrid.n) : 0.0f;
+            if (fragmentation > cparams.defrag_thold) {
+                llama_kv_cache_defrag(kv_hybrid);
+            }
+        }
     }
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
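
The recurring pattern in this patch is that every piece of per-step KV bookkeeping done for the primary cache (slot reservation, rollback when graph compute fails, ring-buffer head advance, defrag check) is mirrored onto a second cache whenever the model is hybrid. The standalone sketch below distills that transactional save/restore shape; the types and functions in it (SimpleCache, SlotRestorer, decode_step) are hypothetical stand-ins for illustration only, not llama.cpp API.

// Minimal sketch of the dual-cache bookkeeping the patch applies.
// Assumption: SimpleCache/SlotRestorer/decode_step are invented names;
// only the control flow mirrors the patch.
#include <cstdint>
#include <cstdio>

struct SimpleCache {
    uint32_t head = 0;   // next write position in the ring buffer
    uint32_t size = 8;   // total number of cells
    bool has_room(uint32_t n) const { return n <= size; }
};

// Remembers a cache's head so a failed step can be undone,
// analogous to the role llama_kv_slot_restorer plays in the patch.
struct SlotRestorer {
    SimpleCache & cache;
    uint32_t saved_head;
    explicit SlotRestorer(SimpleCache & c) : cache(c), saved_head(c.head) {}
    void restore() { cache.head = saved_head; }
};

// One decode step over a primary cache and an optional hybrid cache.
// Returns 0 on success, non-zero on failure, like llama_decode_internal.
static int decode_step(SimpleCache & kv_self, SimpleCache * kv_hybrid,
                       uint32_t n_tokens, bool compute_ok) {
    SlotRestorer restorer(kv_self);
    if (!kv_self.has_room(n_tokens)) {
        return 1;                          // primary slot reservation failed
    }
    if (kv_hybrid) {
        SlotRestorer restorer_hybrid(*kv_hybrid);
        if (!kv_hybrid->has_room(n_tokens)) {
            return 1;                      // hybrid slot reservation failed
        }
        if (!compute_ok) {                 // roll back BOTH caches on failure
            restorer_hybrid.restore();
            restorer.restore();
            return 2;
        }
        // advance the hybrid ring buffer, wrapping exactly as the patch does
        kv_hybrid->head += n_tokens;
        if (kv_hybrid->head >= kv_hybrid->size) {
            kv_hybrid->head = 0;
        }
    } else if (!compute_ok) {
        restorer.restore();
        return 2;
    }
    // advance the primary ring buffer
    kv_self.head += n_tokens;
    if (kv_self.head >= kv_self.size) {
        kv_self.head = 0;
    }
    return 0;
}

int main() {
    SimpleCache kv_self, kv_hybrid;
    int rc = decode_step(kv_self, &kv_hybrid, /*n_tokens=*/3, /*compute_ok=*/true);
    printf("rc=%d heads: self=%u hybrid=%u\n", rc, kv_self.head, kv_hybrid.head);
    return 0;
}

One consequence of this design, visible in the patch as well: the two caches succeed or fail together. A reservation failure in either cache aborts the step before any state is committed, and a compute failure restores both heads, so the attention cache and the recurrent-state cache can never drift out of sync across a partially completed decode.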