feat: Update the logic in llama_decode_internal for kv_hybrid cache
Branch: BambaArchitecture Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent
44bf431ab4
commit
4543ed5640
1 changed files with 32 additions and 2 deletions
|
@ -18104,6 +18104,11 @@ static int llama_decode_internal(
|
|||
auto & kv_self = lctx.kv_self;
|
||||
llama_kv_slot_restorer kv_slot_restorer(kv_self);
|
||||
|
||||
// Only used for hybrid-recurrent models (e.g. Bamba)
|
||||
const bool hybrid = llama_model_is_hybrid(&model);
|
||||
auto & kv_hybrid = lctx.kv_hybrid;
|
||||
llama_kv_slot_restorer kv_slot_restorer_hybrid(kv_hybrid);
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_vocab = hparams.n_vocab;
|
||||
|
||||
|
@ -18192,7 +18197,15 @@ static int llama_decode_internal(
|
|||
return 1;
|
||||
}
|
||||
kv_slot_restorer.save(slot);
|
||||
if (hybrid) {
|
||||
const auto slot_hybrid = llama_kv_cache_find_slot(kv_hybrid, ubatch);
|
||||
if (!slot_hybrid) {
|
||||
return 1;
|
||||
}
|
||||
kv_slot_restorer_hybrid.save(slot_hybrid);
|
||||
}
|
||||
|
||||
// TODO: Update this clause for hybrid recurrent models
|
||||
if (!kv_self.recurrent) {
|
||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||
// after enough generations, the benefit from this heuristic disappears
|
||||
|
@ -18241,6 +18254,9 @@ static int llama_decode_internal(
|
|||
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
|
||||
if (compute_status != GGML_STATUS_SUCCESS) {
|
||||
kv_slot_restorer.restore(kv_self);
|
||||
if (hybrid) {
|
||||
kv_slot_restorer_hybrid.restore(kv_hybrid);
|
||||
}
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_ABORTED:
|
||||
return 2;
|
||||
|
@ -18252,7 +18268,7 @@ static int llama_decode_internal(
|
|||
}
|
||||
}
|
||||
|
||||
// update the kv ring buffer
|
||||
// update the kv ring buffer(s)
|
||||
{
|
||||
kv_self.head += n_tokens;
|
||||
|
||||
|
@ -18260,6 +18276,13 @@ static int llama_decode_internal(
|
|||
if (kv_self.head >= kv_self.size) {
|
||||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
if (hybrid) {
|
||||
kv_hybrid.head += n_tokens;
|
||||
if (kv_hybrid.head >= kv_hybrid.size) {
|
||||
kv_hybrid.head = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// plot the computation graph in dot format (for debugging purposes)
|
||||
|
@ -18366,7 +18389,7 @@ static int llama_decode_internal(
|
|||
// wait for the computation to finish (automatically done when obtaining the model output)
|
||||
//llama_synchronize(&lctx);
|
||||
|
||||
// decide if we need to defrag the kv cache
|
||||
// decide if we need to defrag the kv cache(s)
|
||||
if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
|
||||
const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
|
||||
|
||||
|
@ -18376,6 +18399,13 @@ static int llama_decode_internal(
|
|||
|
||||
llama_kv_cache_defrag(kv_self);
|
||||
}
|
||||
|
||||
if (hybrid) {
|
||||
const float fragmentation = kv_hybrid.n >= 128 ? 1.0f - float(kv_hybrid.used)/float(kv_hybrid.n) : 0.0f;
|
||||
if (fragmentation > cparams.defrag_thold) {
|
||||
llama_kv_cache_defrag(kv_hybrid);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reset state for the next token before backend sync, to allow the CPU activities in the reset to
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue