feat: Update the logic in llama_decode_internal for kv_hybrid cache

Branch: BambaArchitecture

Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
Gabe Goodhart 2024-12-10 11:00:01 -07:00
parent 44bf431ab4
commit 4543ed5640

View file

@ -18104,6 +18104,11 @@ static int llama_decode_internal(
auto & kv_self = lctx.kv_self;
llama_kv_slot_restorer kv_slot_restorer(kv_self);
// Only used for hybrid-recurrent models (e.g. Bamba)
const bool hybrid = llama_model_is_hybrid(&model);
auto & kv_hybrid = lctx.kv_hybrid;
llama_kv_slot_restorer kv_slot_restorer_hybrid(kv_hybrid);
const int64_t n_embd = hparams.n_embd;
const int64_t n_vocab = hparams.n_vocab;
@ -18192,7 +18197,15 @@ static int llama_decode_internal(
return 1;
}
kv_slot_restorer.save(slot);
if (hybrid) {
const auto slot_hybrid = llama_kv_cache_find_slot(kv_hybrid, ubatch);
if (!slot_hybrid) {
return 1;
}
kv_slot_restorer_hybrid.save(slot_hybrid);
}
// TODO: Update this clause for hybrid recurrent models
if (!kv_self.recurrent) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
@ -18241,6 +18254,9 @@ static int llama_decode_internal(
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
if (compute_status != GGML_STATUS_SUCCESS) {
kv_slot_restorer.restore(kv_self);
if (hybrid) {
kv_slot_restorer_hybrid.restore(kv_hybrid);
}
switch (compute_status) {
case GGML_STATUS_ABORTED:
return 2;
@ -18252,7 +18268,7 @@ static int llama_decode_internal(
}
}
// update the kv ring buffer
// update the kv ring buffer(s)
{
kv_self.head += n_tokens;
@ -18260,6 +18276,13 @@ static int llama_decode_internal(
if (kv_self.head >= kv_self.size) {
kv_self.head = 0;
}
if (hybrid) {
kv_hybrid.head += n_tokens;
if (kv_hybrid.head >= kv_hybrid.size) {
kv_hybrid.head = 0;
}
}
}
// plot the computation graph in dot format (for debugging purposes)
@ -18366,7 +18389,7 @@ static int llama_decode_internal(
// wait for the computation to finish (automatically done when obtaining the model output)
//llama_synchronize(&lctx);
// decide if we need to defrag the kv cache
// decide if we need to defrag the kv cache(s)
if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
@ -18376,6 +18399,13 @@ static int llama_decode_internal(
llama_kv_cache_defrag(kv_self);
}
if (hybrid) {
const float fragmentation = kv_hybrid.n >= 128 ? 1.0f - float(kv_hybrid.used)/float(kv_hybrid.n) : 0.0f;
if (fragmentation > cparams.defrag_thold) {
llama_kv_cache_defrag(kv_hybrid);
}
}
}
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to