diff --git a/src/llama.cpp b/src/llama.cpp
index c9e11e688..85e613a63 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -2811,6 +2811,22 @@ struct llama_kv_cache {
     }
 };
 
+// saves the kv_cache state for future recovery
+// used to preserve the kv_cache state before searching for a slot
+struct llama_kv_slot_restorer {
+    struct llama_kv_cache_state {
+        uint32_t head = 0;
+        uint32_t size = 0;
+        uint32_t used = 0;
+        uint32_t n    = 0;
+    } old_state;
+
+    std::vector<llama_kv_cell> recurrent_cells;    // for recurrent models only
+    std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
+
+    bool restore = false;
+};
+
 struct llama_control_vector {
     std::vector<struct ggml_tensor *> tensors; // per layer
     std::vector<struct ggml_context *> ctxs;
@@ -3508,11 +3524,19 @@ static bool llama_kv_cache_init(
 // to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
-       const struct llama_ubatch & batch) {
+       const struct llama_ubatch & batch,
+  struct llama_kv_slot_restorer * slot_restorer = nullptr) {
     const uint32_t n_tokens = batch.n_tokens;
     const uint32_t n_seqs   = batch.n_seqs;
     const uint32_t n_seq_tokens = batch.n_seq_tokens;
 
+    if (slot_restorer != nullptr) {
+        slot_restorer->old_state.head = cache.head;
+        slot_restorer->old_state.size = cache.size;
+        slot_restorer->old_state.used = cache.used;
+        slot_restorer->old_state.n    = cache.n;
+    }
+
     if (cache.recurrent) {
         // For recurrent state architectures (like Mamba or RWKV),
         // each cache cell can store the state for a whole sequence.
@@ -3521,6 +3545,11 @@ static bool llama_kv_cache_find_slot(
         // can only process batches with an equal number of new tokens in each sequence
         GGML_ASSERT(batch.equal_seqs);
 
+        if (slot_restorer != nullptr) {
+            slot_restorer->recurrent_cells = cache.cells;
+            slot_restorer->restore = true;
+        }
+
         int32_t min = cache.size - 1;
         int32_t max = 0;
 
@@ -3709,6 +3738,11 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
+    if (slot_restorer != nullptr) {
+        slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
+        slot_restorer->restore = true;
+    }
+
     for (uint32_t s = 0; s < n_seqs; s++) {
         for (uint32_t i = 0; i < n_seq_tokens; ++i) {
             uint32_t k = s*n_seq_tokens + i;
@@ -3998,6 +4032,23 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
     return cparams.flash_attn ? 256u : 32u;
 }
 
+static void llama_kv_cache_slot_restore(
+        const struct llama_kv_slot_restorer & restorer,
+              struct llama_kv_cache & cache) {
+    if (restorer.restore) {
+        cache.head = restorer.old_state.head;
+        cache.size = restorer.old_state.size;
+        cache.used = restorer.old_state.used;
+        cache.n    = restorer.old_state.n;
+
+        if (cache.recurrent) {
+            cache.cells = restorer.recurrent_cells;
+        } else {
+            llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
+        }
+    }
+}
+
 //
 // model loading and saving
 //
@@ -17256,6 +17307,7 @@ static int llama_decode_internal(
     lctx.n_queued_tokens += n_tokens_all;
 
     auto & kv_self = lctx.kv_self;
+    llama_kv_slot_restorer kv_slot_restorer;
 
     const int64_t n_embd  = hparams.n_embd;
     const int64_t n_vocab = hparams.n_vocab;
@@ -17340,7 +17392,7 @@ static int llama_decode_internal(
                 kv_self.head = 0;
             }
 
-            if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
+            if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
                 return 1;
             }
 
@@ -17390,16 +17442,17 @@ static int llama_decode_internal(
         llama_set_inputs(lctx, ubatch);
 
         const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
-        switch (compute_status) {
-            case GGML_STATUS_SUCCESS:
-                break;
-            case GGML_STATUS_ABORTED:
-                return 2;
-            case GGML_STATUS_ALLOC_FAILED:
-                return -2;
-            case GGML_STATUS_FAILED:
-            default:
-                return -3;
+        if (compute_status != GGML_STATUS_SUCCESS) {
+            llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
+            switch (compute_status) {
+                case GGML_STATUS_ABORTED:
+                    return 2;
+                case GGML_STATUS_ALLOC_FAILED:
+                    return -2;
+                case GGML_STATUS_FAILED:
+                default:
+                    return -3;
+            }
         }
 
         // update the kv ring buffer
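
For reference, below is a minimal, self-contained sketch of the rollback pattern this patch applies: snapshot the cache's mutable bookkeeping before a fallible operation, and restore it if the operation fails. The types and function names (toy_cache, toy_restorer, save, restore) are hypothetical stand-ins for illustration only, not the llama.cpp structures touched by the diff.

// sketch of the save/restore pattern, with simplified stand-in types
#include <cstdint>
#include <cstdio>
#include <vector>

struct toy_cache {
    uint32_t head = 0;
    uint32_t used = 0;
    std::vector<int> cells;   // stands in for per-cell metadata
};

struct toy_restorer {
    uint32_t old_head = 0;
    uint32_t old_used = 0;
    std::vector<int> old_cells;
    bool do_restore = false;
};

// snapshot the mutable bookkeeping before a fallible operation
static void save(const toy_cache & cache, toy_restorer & r) {
    r.old_head   = cache.head;
    r.old_used   = cache.used;
    r.old_cells  = cache.cells;
    r.do_restore = true;
}

// roll the cache back to the saved state if the operation failed
static void restore(const toy_restorer & r, toy_cache & cache) {
    if (r.do_restore) {
        cache.head  = r.old_head;
        cache.used  = r.old_used;
        cache.cells = r.old_cells;
    }
}

int main() {
    toy_cache cache;
    cache.cells.assign(8, 0);

    toy_restorer r;
    save(cache, r);

    // mutate the cache as if a slot had been reserved for 3 tokens
    cache.head = 3;
    cache.used = 3;
    cache.cells[0] = cache.cells[1] = cache.cells[2] = 1;

    const bool compute_ok = false;  // pretend graph compute failed
    if (!compute_ok) {
        restore(r, cache);          // cache is back to its pre-slot state
    }

    printf("head=%u used=%u cell0=%d\n", cache.head, cache.used, cache.cells[0]);
    return 0;
}

The same idea drives the diff: the restorer is filled in by llama_kv_cache_find_slot, and llama_decode_internal only invokes the restore path when llama_graph_compute reports a non-success status, so the successful path keeps its previous behavior.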