llama: restore a kv_cache in case of failed computation
This commit is contained in:
parent
acb9528362
commit
0026c810d7
1 changed files with 65 additions and 12 deletions
|
@ -2811,6 +2811,22 @@ struct llama_kv_cache {
|
|||
}
|
||||
};
|
||||
|
||||
// saves the kv_cache state for future recovery
|
||||
// used to preserve the kv_cache state before searching for a slot
|
||||
struct llama_kv_slot_restorer {
|
||||
struct llama_kv_cache_state {
|
||||
uint32_t head = 0;
|
||||
uint32_t size = 0;
|
||||
uint32_t used = 0;
|
||||
uint32_t n = 0;
|
||||
} old_state;
|
||||
|
||||
std::vector<llama_kv_cell> recurrent_cells; // for recurrent models only
|
||||
std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
|
||||
|
||||
bool restore = false;
|
||||
};
|
||||
|
||||
struct llama_control_vector {
|
||||
std::vector<struct ggml_tensor *> tensors; // per layer
|
||||
std::vector<ggml_context_ptr> ctxs;
|
||||
|
@ -3508,11 +3524,19 @@ static bool llama_kv_cache_init(
|
|||
// to the first cell of the slot.
|
||||
static bool llama_kv_cache_find_slot(
|
||||
struct llama_kv_cache & cache,
|
||||
const struct llama_ubatch & batch) {
|
||||
const struct llama_ubatch & batch,
|
||||
struct llama_kv_slot_restorer * slot_restorer = nullptr) {
|
||||
const uint32_t n_tokens = batch.n_tokens;
|
||||
const uint32_t n_seqs = batch.n_seqs;
|
||||
const uint32_t n_seq_tokens = batch.n_seq_tokens;
|
||||
|
||||
if (slot_restorer != nullptr) {
|
||||
slot_restorer->old_state.head = cache.head;
|
||||
slot_restorer->old_state.size = cache.size;
|
||||
slot_restorer->old_state.used = cache.used;
|
||||
slot_restorer->old_state.n = cache.n;
|
||||
}
|
||||
|
||||
if (cache.recurrent) {
|
||||
// For recurrent state architectures (like Mamba or RWKV),
|
||||
// each cache cell can store the state for a whole sequence.
|
||||
|
@ -3521,6 +3545,11 @@ static bool llama_kv_cache_find_slot(
|
|||
// can only process batches with an equal number of new tokens in each sequence
|
||||
GGML_ASSERT(batch.equal_seqs);
|
||||
|
||||
if (slot_restorer != nullptr) {
|
||||
slot_restorer->recurrent_cells = cache.cells;
|
||||
slot_restorer->restore = true;
|
||||
}
|
||||
|
||||
int32_t min = cache.size - 1;
|
||||
int32_t max = 0;
|
||||
|
||||
|
@ -3709,6 +3738,11 @@ static bool llama_kv_cache_find_slot(
|
|||
}
|
||||
}
|
||||
|
||||
if (slot_restorer != nullptr) {
|
||||
slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
|
||||
slot_restorer->restore = true;
|
||||
}
|
||||
|
||||
for (uint32_t s = 0; s < n_seqs; s++) {
|
||||
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
|
||||
uint32_t k = s*n_seq_tokens + i;
|
||||
|
@ -3998,6 +4032,23 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
|
|||
return cparams.flash_attn ? 256u : 32u;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_slot_restore(
|
||||
const struct llama_kv_slot_restorer & restorer,
|
||||
struct llama_kv_cache & cache) {
|
||||
if (restorer.restore) {
|
||||
cache.head = restorer.old_state.head;
|
||||
cache.size = restorer.old_state.size;
|
||||
cache.used = restorer.old_state.used;
|
||||
cache.n = restorer.old_state.n;
|
||||
|
||||
if (cache.recurrent) {
|
||||
cache.cells = restorer.recurrent_cells;
|
||||
} else {
|
||||
llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
//
|
||||
// model loading and saving
|
||||
//
|
||||
|
@ -17256,6 +17307,7 @@ static int llama_decode_internal(
|
|||
lctx.n_queued_tokens += n_tokens_all;
|
||||
|
||||
auto & kv_self = lctx.kv_self;
|
||||
llama_kv_slot_restorer kv_slot_restorer;
|
||||
|
||||
const int64_t n_embd = hparams.n_embd;
|
||||
const int64_t n_vocab = hparams.n_vocab;
|
||||
|
@ -17340,7 +17392,7 @@ static int llama_decode_internal(
|
|||
kv_self.head = 0;
|
||||
}
|
||||
|
||||
if (!llama_kv_cache_find_slot(kv_self, ubatch)) {
|
||||
if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
|
@ -17390,16 +17442,17 @@ static int llama_decode_internal(
|
|||
llama_set_inputs(lctx, ubatch);
|
||||
|
||||
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_SUCCESS:
|
||||
break;
|
||||
case GGML_STATUS_ABORTED:
|
||||
return 2;
|
||||
case GGML_STATUS_ALLOC_FAILED:
|
||||
return -2;
|
||||
case GGML_STATUS_FAILED:
|
||||
default:
|
||||
return -3;
|
||||
if (compute_status != GGML_STATUS_SUCCESS) {
|
||||
llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
|
||||
switch (compute_status) {
|
||||
case GGML_STATUS_ABORTED:
|
||||
return 2;
|
||||
case GGML_STATUS_ALLOC_FAILED:
|
||||
return -2;
|
||||
case GGML_STATUS_FAILED:
|
||||
default:
|
||||
return -3;
|
||||
}
|
||||
}
|
||||
|
||||
// update the kv ring buffer
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue