feat: Update the logic in llama_decode_internal for kv_hybrid cache
Branch: BambaArchitecture
Signed-off-by: Gabe Goodhart <ghart@us.ibm.com>
This commit is contained in:
parent 44bf431ab4
commit 4543ed5640

1 changed file with 32 additions and 2 deletions
@@ -18104,6 +18104,11 @@ static int llama_decode_internal(
     auto & kv_self = lctx.kv_self;
     llama_kv_slot_restorer kv_slot_restorer(kv_self);
 
+    // Only used for hybrid-recurrent models (e.g. Bamba)
+    const bool hybrid = llama_model_is_hybrid(&model);
+    auto & kv_hybrid = lctx.kv_hybrid;
+    llama_kv_slot_restorer kv_slot_restorer_hybrid(kv_hybrid);
+
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
 
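The hunk above mirrors the existing kv_self bookkeeping for a second cache: a hybrid flag, a reference to lctx.kv_hybrid, and a dedicated slot restorer so the recurrent-state cache can be rolled back on its own. A minimal sketch of the save/restore idea behind such a restorer, using illustrative names (kv_state, kv_guard) rather than the actual llama_kv_slot_restorer API:

```cpp
#include <cstdint>

// Illustrative stand-in for a cache's mutable bookkeeping.
struct kv_state {
    uint32_t head = 0;  // next write position in the ring buffer
    uint32_t used = 0;  // number of occupied cells
};

// Capture the state up front; put it back if decoding fails later.
struct kv_guard {
    kv_state & kv;
    kv_state   saved;
    explicit kv_guard(kv_state & kv) : kv(kv), saved(kv) {}
    void restore() { kv = saved; }
};
```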
@@ -18192,7 +18197,15 @@ static int llama_decode_internal(
                 return 1;
             }
             kv_slot_restorer.save(slot);
+            if (hybrid) {
+                const auto slot_hybrid = llama_kv_cache_find_slot(kv_hybrid, ubatch);
+                if (!slot_hybrid) {
+                    return 1;
+                }
+                kv_slot_restorer_hybrid.save(slot_hybrid);
+            }
 
+            // TODO: Update this clause for hybrid recurrent models
             if (!kv_self.recurrent) {
                 // a heuristic, to avoid attending the full cache if it is not yet utilized
                 // after enough generations, the benefit from this heuristic disappears
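With two caches in play, slot allocation has to succeed in both before the batch is committed; either failure returns 1, the same "no KV slot" code the unified path already uses. A condensed, self-contained sketch of that control flow (reserve_slots and its parameters are illustrative, not llama.cpp API):

```cpp
#include <functional>

// Returns 0 when both reservations succeed, 1 as soon as either cache
// has no room for the micro-batch (mirroring the diff's early returns).
int reserve_slots(bool hybrid,
                  const std::function<bool()> & find_slot_self,
                  const std::function<bool()> & find_slot_hybrid) {
    if (!find_slot_self()) {
        return 1;
    }
    if (hybrid && !find_slot_hybrid()) {
        return 1;
    }
    return 0;
}
```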
@@ -18241,6 +18254,9 @@ static int llama_decode_internal(
         const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
         if (compute_status != GGML_STATUS_SUCCESS) {
             kv_slot_restorer.restore(kv_self);
+            if (hybrid) {
+                kv_slot_restorer_hybrid.restore(kv_hybrid);
+            }
             switch (compute_status) {
                 case GGML_STATUS_ABORTED:
                     return 2;
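On graph-compute failure, both restorers roll their caches back before the GGML status is mapped onto llama_decode's return codes; only the ABORTED -> 2 case is visible in this hunk, so the other mappings below are assumptions. A standalone sketch of the error path:

```cpp
#include <functional>

enum class compute_status { success, aborted, alloc_failed, failed };

// Hedged sketch: undo both caches, then translate the backend status.
int handle_compute_failure(bool hybrid,
                           const std::function<void()> & restore_self,
                           const std::function<void()> & restore_hybrid,
                           compute_status s) {
    restore_self();
    if (hybrid) {
        restore_hybrid();
    }
    switch (s) {
        case compute_status::aborted:      return  2;  // shown in the diff
        case compute_status::alloc_failed: return -2;  // assumed mapping
        default:                           return -3;  // assumed mapping
    }
}
```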
@@ -18252,7 +18268,7 @@ static int llama_decode_internal(
             }
         }
 
-        // update the kv ring buffer
+        // update the kv ring buffer(s)
         {
             kv_self.head += n_tokens;
 
@@ -18260,6 +18276,13 @@ static int llama_decode_internal(
             if (kv_self.head >= kv_self.size) {
                 kv_self.head = 0;
             }
+
+            if (hybrid) {
+                kv_hybrid.head += n_tokens;
+                if (kv_hybrid.head >= kv_hybrid.size) {
+                    kv_hybrid.head = 0;
+                }
+            }
         }
 
         // plot the computation graph in dot format (for debugging purposes)
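After a successful decode, each cache advances its own head by n_tokens and wraps to zero past the end of its buffer; the heads are tracked independently, presumably because the two caches are sized differently. The wraparound in isolation:

```cpp
#include <cassert>
#include <cstdint>

// Advance a ring-buffer head by n_tokens, wrapping to 0 at the end,
// as the diff does for both kv_self.head and kv_hybrid.head.
static inline void advance_head(uint32_t & head, uint32_t n_tokens, uint32_t size) {
    assert(size > 0);
    head += n_tokens;
    if (head >= size) {
        head = 0;
    }
}
```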
@@ -18366,7 +18389,7 @@ static int llama_decode_internal(
     // wait for the computation to finish (automatically done when obtaining the model output)
     //llama_synchronize(&lctx);
 
-    // decide if we need to defrag the kv cache
+    // decide if we need to defrag the kv cache(s)
     if (cparams.causal_attn && cparams.defrag_thold >= 0.0f) {
         const float fragmentation = kv_self.n >= 128 ? 1.0f - float(kv_self.used)/float(kv_self.n) : 0.0f;
 
@@ -18376,6 +18399,13 @@ static int llama_decode_internal(
 
             llama_kv_cache_defrag(kv_self);
         }
+
+        if (hybrid) {
+            const float fragmentation = kv_hybrid.n >= 128 ? 1.0f - float(kv_hybrid.used)/float(kv_hybrid.n) : 0.0f;
+            if (fragmentation > cparams.defrag_thold) {
+                llama_kv_cache_defrag(kv_hybrid);
+            }
+        }
     }
 
     // Reset state for the next token before backend sync, to allow the CPU activities in the reset to
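The defrag check is duplicated for the hybrid cache with the same heuristic: once at least 128 cells are in the window, fragmentation is 1 - used/n, and a defrag is triggered when it exceeds cparams.defrag_thold. The metric as a standalone helper:

```cpp
#include <cstdint>

// Fragmentation heuristic from the diff: with fewer than 128 active cells
// the cache is considered too small to measure, so report 0 (never defrag).
static inline float kv_fragmentation(uint32_t used, uint32_t n) {
    return n >= 128 ? 1.0f - float(used) / float(n) : 0.0f;
}

// Example: n = 256 cells in the window, used = 192 -> fragmentation 0.25;
// with defrag_thold = 0.1 this would trigger a defrag pass.
```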