llama : defragment via non-overlapping moves
This commit is contained in:
parent
2d7203b975
commit
65323bc770
1 changed file with 58 additions and 13 deletions
71
llama.cpp
71
llama.cpp
|
@ -8028,7 +8028,7 @@ static int llama_decode_internal(
|
||||||
}
|
}
|
||||||
|
|
||||||
// copy the KV cache to the host memory and reshuffle the cells to the beginning of the cache
|
// copy the KV cache to the host memory and reshuffle the cells to the beginning of the cache
|
||||||
// this way we eliminate any empty segments that may have been left by previous KV cache operations
|
// this way we eliminate any empty holes that may have been left by previous KV cache operations
|
||||||
//
|
//
|
||||||
// TODO: optimizations are possible:
|
// TODO: optimizations are possible:
|
||||||
// - multiple threads
|
// - multiple threads
|
||||||
|
@ -8045,36 +8045,81 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
|
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
|
||||||
|
const uint32_t n_used = kv_self.used;
|
||||||
|
|
||||||
const uint32_t kv_size = kv_self.size;
|
const uint32_t kv_size = kv_self.size;
|
||||||
|
|
||||||
|
assert(n_used <= n_kv);
|
||||||
|
|
||||||
const int64_t t_start = ggml_time_us();
|
const int64_t t_start = ggml_time_us();
|
||||||
|
|
||||||
std::vector<uint8_t> buf_k;
|
std::vector<uint8_t> buf_k;
|
||||||
std::vector<uint8_t> buf_v;
|
std::vector<uint8_t> buf_v;
|
||||||
|
|
||||||
// the destination cell in the new KV cache
|
|
||||||
uint32_t id = 0;
|
|
||||||
|
|
||||||
// number of cells moved
|
// number of cells moved
|
||||||
uint32_t n_moves = 0;
|
uint32_t n_moves = 0;
|
||||||
|
|
||||||
// determine which KV cells to move where
|
// determine which KV cells to move where
|
||||||
std::vector<uint32_t> ids(n_kv, n_kv);
|
std::vector<uint32_t> ids(n_kv, n_kv);
|
||||||
|
|
||||||
for (uint32_t i0 = 0; i0 < n_kv; ++i0) {
|
for (uint32_t i0 = 0; i0 < n_used; ++i0) {
|
||||||
const auto & cell0 = kv_self.cells[i0];
|
const auto & cell0 = kv_self.cells[i0];
|
||||||
|
|
||||||
if (!cell0.is_empty()) {
|
if (!cell0.is_empty()) {
|
||||||
ids[i0] = id;
|
ids[i0] = i0;
|
||||||
|
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// found a hole - fill it with data from the end of the cache
|
||||||
|
|
||||||
|
// determine the size of the hole
|
||||||
|
uint32_t nh = 1;
|
||||||
|
while (i0 + nh < n_used && kv_self.cells[i0 + nh].is_empty()) {
|
||||||
|
nh++;
|
||||||
|
}
|
||||||
|
|
||||||
|
// starting from the end, find nh non-empty cells
|
||||||
|
uint32_t nf = 0;
|
||||||
|
uint32_t is = n_kv - 1;
|
||||||
|
for (; is > i0; --is) {
|
||||||
|
const auto & cell1 = kv_self.cells[is];
|
||||||
|
|
||||||
|
if (cell1.is_empty() || ids[is] != n_kv) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// non-empty cell which is not yet moved
|
||||||
|
nf++;
|
||||||
|
if (nf == nh) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_ASSERT(nf == nh && "KV defrag bug: nf != nh");
|
||||||
|
|
||||||
|
nf = 0;
|
||||||
|
|
||||||
|
// go back and move the nf cells to the hole
|
||||||
|
for (uint32_t i1 = is; i1 < n_kv; ++i1) {
|
||||||
|
const auto & cell1 = kv_self.cells[i1];
|
||||||
|
|
||||||
|
if (cell1.is_empty() || ids[i1] != n_kv) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
ids[i1] = i0 + nf;
|
||||||
|
|
||||||
|
// move the cell meta data
|
||||||
|
kv_self.cells[i0 + nf] = cell1;
|
||||||
|
|
||||||
if (i0 != id) {
|
|
||||||
kv_self.cells[id] = cell0;
|
|
||||||
n_moves++;
|
n_moves++;
|
||||||
|
nf++;
|
||||||
}
|
}
|
||||||
|
|
||||||
id++;
|
LLAMA_LOG_INFO("(tmp log) KV defrag: move [%u, %u) to [%u, %u)\n", is, n_kv, i0, i0 + nh);
|
||||||
}
|
|
||||||
|
i0 += nh - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (n_moves == 0) {
|
if (n_moves == 0) {
|
||||||
|
@ -8083,11 +8128,11 @@ static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||||
|
|
||||||
LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
||||||
|
|
||||||
kv_self.head = id;
|
kv_self.head = n_used;
|
||||||
kv_self.used = id;
|
kv_self.used = n_used;
|
||||||
|
|
||||||
// zero the rest of the cells
|
// zero the rest of the cells
|
||||||
for (uint32_t i = id; i < n_kv; ++i) {
|
for (uint32_t i = n_used; i < n_kv; ++i) {
|
||||||
kv_self.cells[i] = llama_kv_cell();
|
kv_self.cells[i] = llama_kv_cell();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue