llama: correct reverting of the entire batch.
also updates `llama_kv_cache_find_slot`, will correctly count the number of `used` cells for recurrent models
This commit is contained in:
parent
0026c810d7
commit
ee599f901a
1 changed files with 63 additions and 57 deletions
120
src/llama.cpp
120
src/llama.cpp
|
@ -2811,22 +2811,6 @@ struct llama_kv_cache {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// saves the kv_cache state for future recovery
|
|
||||||
// used to preserve the kv_cache state before searching for a slot
|
|
||||||
struct llama_kv_slot_restorer {
|
|
||||||
struct llama_kv_cache_state {
|
|
||||||
uint32_t head = 0;
|
|
||||||
uint32_t size = 0;
|
|
||||||
uint32_t used = 0;
|
|
||||||
uint32_t n = 0;
|
|
||||||
} old_state;
|
|
||||||
|
|
||||||
std::vector<llama_kv_cell> recurrent_cells; // for recurrent models only
|
|
||||||
std::pair<uint32_t, uint32_t> slot_boundaries; // for non-recurrent models only
|
|
||||||
|
|
||||||
bool restore = false;
|
|
||||||
};
|
|
||||||
|
|
||||||
struct llama_control_vector {
|
struct llama_control_vector {
|
||||||
std::vector<struct ggml_tensor *> tensors; // per layer
|
std::vector<struct ggml_tensor *> tensors; // per layer
|
||||||
std::vector<ggml_context_ptr> ctxs;
|
std::vector<ggml_context_ptr> ctxs;
|
||||||
|
@ -3522,21 +3506,24 @@ static bool llama_kv_cache_init(
|
||||||
// updates the cache head
|
// updates the cache head
|
||||||
// Note: On success, it's important that cache.head points
|
// Note: On success, it's important that cache.head points
|
||||||
// to the first cell of the slot.
|
// to the first cell of the slot.
|
||||||
static bool llama_kv_cache_find_slot(
|
struct llama_kv_cache_slot_info {
|
||||||
|
std::pair<uint32_t, uint32_t> boundaries;
|
||||||
|
bool found = false;
|
||||||
|
|
||||||
|
explicit llama_kv_cache_slot_info(bool found_) : found{found_} {}
|
||||||
|
llama_kv_cache_slot_info(uint32_t begin, uint32_t end) : boundaries{begin, end}, found{true} {}
|
||||||
|
|
||||||
|
operator bool() const { return found; }
|
||||||
|
};
|
||||||
|
static const llama_kv_cache_slot_info llama_kv_cache_slot_info_failed{false};
|
||||||
|
|
||||||
|
static struct llama_kv_cache_slot_info llama_kv_cache_find_slot(
|
||||||
struct llama_kv_cache & cache,
|
struct llama_kv_cache & cache,
|
||||||
const struct llama_ubatch & batch,
|
const struct llama_ubatch & batch) {
|
||||||
struct llama_kv_slot_restorer * slot_restorer = nullptr) {
|
|
||||||
const uint32_t n_tokens = batch.n_tokens;
|
const uint32_t n_tokens = batch.n_tokens;
|
||||||
const uint32_t n_seqs = batch.n_seqs;
|
const uint32_t n_seqs = batch.n_seqs;
|
||||||
const uint32_t n_seq_tokens = batch.n_seq_tokens;
|
const uint32_t n_seq_tokens = batch.n_seq_tokens;
|
||||||
|
|
||||||
if (slot_restorer != nullptr) {
|
|
||||||
slot_restorer->old_state.head = cache.head;
|
|
||||||
slot_restorer->old_state.size = cache.size;
|
|
||||||
slot_restorer->old_state.used = cache.used;
|
|
||||||
slot_restorer->old_state.n = cache.n;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (cache.recurrent) {
|
if (cache.recurrent) {
|
||||||
// For recurrent state architectures (like Mamba or RWKV),
|
// For recurrent state architectures (like Mamba or RWKV),
|
||||||
// each cache cell can store the state for a whole sequence.
|
// each cache cell can store the state for a whole sequence.
|
||||||
|
@ -3545,11 +3532,6 @@ static bool llama_kv_cache_find_slot(
|
||||||
// can only process batches with an equal number of new tokens in each sequence
|
// can only process batches with an equal number of new tokens in each sequence
|
||||||
GGML_ASSERT(batch.equal_seqs);
|
GGML_ASSERT(batch.equal_seqs);
|
||||||
|
|
||||||
if (slot_restorer != nullptr) {
|
|
||||||
slot_restorer->recurrent_cells = cache.cells;
|
|
||||||
slot_restorer->restore = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
int32_t min = cache.size - 1;
|
int32_t min = cache.size - 1;
|
||||||
int32_t max = 0;
|
int32_t max = 0;
|
||||||
|
|
||||||
|
@ -3563,7 +3545,7 @@ static bool llama_kv_cache_find_slot(
|
||||||
// too big seq_id
|
// too big seq_id
|
||||||
// TODO: would it be possible to resize the cache instead?
|
// TODO: would it be possible to resize the cache instead?
|
||||||
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
|
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, cache.size);
|
||||||
return false;
|
return llama_kv_cache_slot_info_failed;
|
||||||
}
|
}
|
||||||
if (j > 0) {
|
if (j > 0) {
|
||||||
llama_kv_cell & seq = cache.cells[seq_id];
|
llama_kv_cell & seq = cache.cells[seq_id];
|
||||||
|
@ -3698,15 +3680,17 @@ static bool llama_kv_cache_find_slot(
|
||||||
// allow getting the range of used cells, from head to head + n
|
// allow getting the range of used cells, from head to head + n
|
||||||
cache.head = min;
|
cache.head = min;
|
||||||
cache.n = max - min + 1;
|
cache.n = max - min + 1;
|
||||||
|
cache.used = std::count_if(cache.cells.begin(), cache.cells.end(),
|
||||||
|
[](const llama_kv_cell& cell){ return !cell.is_empty(); });
|
||||||
|
|
||||||
// sanity check
|
// sanity check
|
||||||
return cache.n >= n_seqs;
|
return llama_kv_cache_slot_info(cache.n >= n_seqs);
|
||||||
}
|
}
|
||||||
// otherwise, one cell per token.
|
// otherwise, one cell per token.
|
||||||
|
|
||||||
if (n_tokens > cache.size) {
|
if (n_tokens > cache.size) {
|
||||||
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
|
LLAMA_LOG_ERROR("%s: n_tokens=%d > cache.size=%d\n", __func__, n_tokens, cache.size);
|
||||||
return false;
|
return llama_kv_cache_slot_info_failed;
|
||||||
}
|
}
|
||||||
|
|
||||||
uint32_t n_tested = 0;
|
uint32_t n_tested = 0;
|
||||||
|
@ -3734,15 +3718,10 @@ static bool llama_kv_cache_find_slot(
|
||||||
|
|
||||||
if (n_tested >= cache.size) {
|
if (n_tested >= cache.size) {
|
||||||
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
//LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens);
|
||||||
return false;
|
return llama_kv_cache_slot_info_failed;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot_restorer != nullptr) {
|
|
||||||
slot_restorer->slot_boundaries = std::make_pair(cache.head, cache.head + n_tokens);
|
|
||||||
slot_restorer->restore = true;
|
|
||||||
}
|
|
||||||
|
|
||||||
for (uint32_t s = 0; s < n_seqs; s++) {
|
for (uint32_t s = 0; s < n_seqs; s++) {
|
||||||
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
|
for (uint32_t i = 0; i < n_seq_tokens; ++i) {
|
||||||
uint32_t k = s*n_seq_tokens + i;
|
uint32_t k = s*n_seq_tokens + i;
|
||||||
|
@ -3756,7 +3735,7 @@ static bool llama_kv_cache_find_slot(
|
||||||
|
|
||||||
cache.used += n_tokens;
|
cache.used += n_tokens;
|
||||||
|
|
||||||
return true;
|
return llama_kv_cache_slot_info(cache.head, cache.head + n_tokens);
|
||||||
}
|
}
|
||||||
|
|
||||||
// find how many cells are currently in use
|
// find how many cells are currently in use
|
||||||
|
@ -4032,22 +4011,47 @@ static uint32_t llama_kv_cache_get_padding(const struct llama_cparams & cparams)
|
||||||
return cparams.flash_attn ? 256u : 32u;
|
return cparams.flash_attn ? 256u : 32u;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_kv_cache_slot_restore(
|
// saves the kv_cache state for future recovery.
|
||||||
const struct llama_kv_slot_restorer & restorer,
|
// used to rollback llama_kv_cache_find_slot changes.
|
||||||
struct llama_kv_cache & cache) {
|
struct llama_kv_slot_restorer {
|
||||||
if (restorer.restore) {
|
struct llama_kv_cache_state {
|
||||||
cache.head = restorer.old_state.head;
|
uint32_t head = 0;
|
||||||
cache.size = restorer.old_state.size;
|
uint32_t n = 0;
|
||||||
cache.used = restorer.old_state.used;
|
} old_state;
|
||||||
cache.n = restorer.old_state.n;
|
|
||||||
|
|
||||||
if (cache.recurrent) {
|
std::vector<std::pair<uint32_t, uint32_t>> slot_boundaries; // for non-recurrent models only
|
||||||
cache.cells = restorer.recurrent_cells;
|
|
||||||
} else {
|
bool do_restore = false;
|
||||||
llama_kv_cache_seq_rm(cache, -1, restorer.slot_boundaries.first, restorer.slot_boundaries.second + 1);
|
|
||||||
|
explicit llama_kv_slot_restorer(const struct llama_kv_cache & cache) {
|
||||||
|
old_state.head = cache.head;
|
||||||
|
old_state.n = cache.n;
|
||||||
|
}
|
||||||
|
|
||||||
|
void save(const struct llama_kv_cache_slot_info& slot) {
|
||||||
|
if (slot) {
|
||||||
|
do_restore = true;
|
||||||
|
if (slot.boundaries.first != slot.boundaries.second) {
|
||||||
|
slot_boundaries.push_back(slot.boundaries);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
void restore(struct llama_kv_cache & cache) {
|
||||||
|
if (do_restore) {
|
||||||
|
cache.head = old_state.head;
|
||||||
|
cache.n = old_state.n;
|
||||||
|
|
||||||
|
if (cache.recurrent) { // recurrent models like Mamba or RWKV can't have a state partially erased
|
||||||
|
llama_kv_cache_seq_rm(cache, -1, -1, -1);
|
||||||
|
} else {
|
||||||
|
for (auto & slot : slot_boundaries) {
|
||||||
|
llama_kv_cache_seq_rm(cache, -1, slot.first, slot.second);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
// model loading and saving
|
// model loading and saving
|
||||||
|
@ -17307,7 +17311,7 @@ static int llama_decode_internal(
|
||||||
lctx.n_queued_tokens += n_tokens_all;
|
lctx.n_queued_tokens += n_tokens_all;
|
||||||
|
|
||||||
auto & kv_self = lctx.kv_self;
|
auto & kv_self = lctx.kv_self;
|
||||||
llama_kv_slot_restorer kv_slot_restorer;
|
llama_kv_slot_restorer kv_slot_restorer(kv_self);
|
||||||
|
|
||||||
const int64_t n_embd = hparams.n_embd;
|
const int64_t n_embd = hparams.n_embd;
|
||||||
const int64_t n_vocab = hparams.n_vocab;
|
const int64_t n_vocab = hparams.n_vocab;
|
||||||
|
@ -17392,9 +17396,11 @@ static int llama_decode_internal(
|
||||||
kv_self.head = 0;
|
kv_self.head = 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!llama_kv_cache_find_slot(kv_self, ubatch, &kv_slot_restorer)) {
|
const auto slot = llama_kv_cache_find_slot(kv_self, ubatch);
|
||||||
|
if (!slot) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
kv_slot_restorer.save(slot);
|
||||||
|
|
||||||
if (!kv_self.recurrent) {
|
if (!kv_self.recurrent) {
|
||||||
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
// a heuristic, to avoid attending the full cache if it is not yet utilized
|
||||||
|
@ -17443,7 +17449,7 @@ static int llama_decode_internal(
|
||||||
|
|
||||||
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
|
const auto compute_status = llama_graph_compute(lctx, gf, n_threads, threadpool);
|
||||||
if (compute_status != GGML_STATUS_SUCCESS) {
|
if (compute_status != GGML_STATUS_SUCCESS) {
|
||||||
llama_kv_cache_slot_restore(kv_slot_restorer, kv_self);
|
kv_slot_restorer.restore(kv_self);
|
||||||
switch (compute_status) {
|
switch (compute_status) {
|
||||||
case GGML_STATUS_ABORTED:
|
case GGML_STATUS_ABORTED:
|
||||||
return 2;
|
return 2;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue