llama : add llama_kv_cache_defrag
This commit is contained in:
parent
9ec749df59
commit
65f21ec5d3
3 changed files with 383 additions and 231 deletions
|
@ -183,6 +183,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
llama_kv_cache_defrag (ctx);
|
||||
llama_kv_cache_update (ctx);
|
||||
|
||||
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
||||
|
@ -213,6 +214,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_kv_cache_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
|
||||
llama_kv_cache_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
|
||||
llama_kv_cache_defrag (ctx);
|
||||
llama_kv_cache_update (ctx);
|
||||
|
||||
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
||||
|
|
601
llama.cpp
601
llama.cpp
|
@ -1722,6 +1722,7 @@ struct llama_kv_cell {
|
|||
// ring-buffer of cached KV data
|
||||
struct llama_kv_cache {
|
||||
bool has_shift = false;
|
||||
bool do_defrag = false;
|
||||
|
||||
// Note: The value of head isn't only used to optimize searching
|
||||
// for a free KV slot. llama_decode_internal also uses it, so it
|
||||
|
@ -2278,6 +2279,10 @@ static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos del
|
|||
cache.compress_delta = delta;
|
||||
}
|
||||
|
||||
static void llama_kv_cache_defrag(struct llama_kv_cache & cache) {
|
||||
cache.do_defrag = true;
|
||||
}
|
||||
|
||||
//
|
||||
// model loading and saving
|
||||
//
|
||||
|
@ -8029,6 +8034,359 @@ static int llama_decode_internal(
|
|||
return 0;
|
||||
}
|
||||
|
||||
// summary:
|
||||
//
|
||||
// - determine which KV cell pairs (i0, i1) to merge:
|
||||
//
|
||||
// abs(cell[i0].pos - cell[i1].pos) <= compress_delta
|
||||
//
|
||||
// - move the KV cache to the Host memory for easier maniiplation
|
||||
// - processing is done layer-by-layer
|
||||
// - convert the KV data to F32
|
||||
// - merge the KV data (different ways to merge)
|
||||
// - convert the KV data back to the original type
|
||||
// - move the KV cache back to the device memory
|
||||
// - update the KV cache metadata
|
||||
//
|
||||
// as a side effect, the new KV cache is defragmented
|
||||
//
|
||||
static void llama_kv_cache_compress_internal(struct llama_context & lctx) {
|
||||
auto & kv_self = lctx.kv_self;
|
||||
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
|
||||
const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
|
||||
const uint32_t n_head_kv = hparams.n_head_kv; GGML_UNUSED(n_head_kv);
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
std::vector<uint8_t> buf_q;
|
||||
|
||||
std::vector<float> buf_src_f32;
|
||||
std::vector<float> buf_dst_f32;
|
||||
|
||||
struct c_pair { uint32_t i0, i1; };
|
||||
struct c_info { bool merged; uint32_t id, cnt, r; };
|
||||
|
||||
std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
|
||||
|
||||
// the destination cell in the new KV cache
|
||||
uint32_t id = 0;
|
||||
|
||||
// number of pairs merged
|
||||
uint32_t n_merges = 0;
|
||||
|
||||
// determine which KV cells to merge
|
||||
for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
|
||||
const auto & cell0 = kv_self.cells[i0];
|
||||
|
||||
if (!cell0.is_empty() && !infos[i0].merged) {
|
||||
infos[i0] = { true, id, 0, 0 };
|
||||
infos[id].cnt = 1;
|
||||
|
||||
const llama_pos p0 = cell0.pos;
|
||||
|
||||
for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
|
||||
const auto & cell1 = kv_self.cells[i1];
|
||||
|
||||
if (i0 != i1 && cell0.is_same_seq(cell1)) {
|
||||
const llama_pos p1 = cell1.pos;
|
||||
|
||||
if (std::abs(p0 - p1) <= kv_self.compress_delta) {
|
||||
infos[i1] = { true, id, 0, 0 };
|
||||
infos[id].cnt++;
|
||||
n_merges++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (i0 != id) {
|
||||
kv_self.cells[id] = cell0;
|
||||
}
|
||||
|
||||
id++;
|
||||
}
|
||||
}
|
||||
|
||||
kv_self.head = id;
|
||||
kv_self.used = id;
|
||||
|
||||
for (uint32_t i = id; i < kv_size; ++i) {
|
||||
kv_self.cells[i] = llama_kv_cell();
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
|
||||
|
||||
ggml_type_traits_t tt_k;
|
||||
ggml_type_traits_t tt_v;
|
||||
|
||||
tt_k = ggml_internal_get_type_traits(kv_self.type_k);
|
||||
tt_v = ggml_internal_get_type_traits(kv_self.type_v);
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
infos[i].r = 0;
|
||||
}
|
||||
|
||||
// update keys
|
||||
{
|
||||
const int64_t ne = n_embd_k_gqa*kv_size;
|
||||
|
||||
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
|
||||
|
||||
buf_q.resize(k_size);
|
||||
|
||||
buf_src_f32.resize(ne);
|
||||
buf_dst_f32.resize(ne);
|
||||
|
||||
ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
|
||||
|
||||
tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
|
||||
|
||||
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
if (!infos[i].merged) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t id = infos[i].id;
|
||||
|
||||
#if 1
|
||||
// merge using averaging
|
||||
{
|
||||
const float scale = 1.0f/float(infos[id].cnt);
|
||||
|
||||
const int64_t os = i*n_embd_k_gqa;
|
||||
const int64_t od = id*n_embd_k_gqa;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
|
||||
buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
|
||||
}
|
||||
}
|
||||
#else
|
||||
// merge separate heads
|
||||
{
|
||||
for (uint32_t h = 0; h < n_head_kv; ++h) {
|
||||
if ((h + il) % infos[id].cnt != infos[id].r) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k;
|
||||
const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_head_k; ++j) {
|
||||
buf_dst_f32[od + j] = buf_src_f32[os + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
infos[id].r++;
|
||||
#endif
|
||||
}
|
||||
|
||||
tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
||||
|
||||
ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
infos[i].r = 0;
|
||||
}
|
||||
|
||||
// update values (note: they are transposed)
|
||||
{
|
||||
const int64_t ne = n_embd_v_gqa*kv_size;
|
||||
|
||||
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
|
||||
|
||||
buf_q.resize(v_size);
|
||||
|
||||
buf_src_f32.resize(ne);
|
||||
buf_dst_f32.resize(ne);
|
||||
|
||||
ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
|
||||
|
||||
tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
|
||||
|
||||
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
if (!infos[i].merged) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t id = infos[i].id;
|
||||
|
||||
#if 1
|
||||
// merge using averaging
|
||||
{
|
||||
const float scale = 1.0f/float(infos[id].cnt);
|
||||
//printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
|
||||
|
||||
const int64_t os = i;
|
||||
const int64_t od = id;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
|
||||
}
|
||||
}
|
||||
#else
|
||||
// merge separate heads
|
||||
{
|
||||
for (uint32_t h = 0; h < n_head_kv; ++h) {
|
||||
if ((h + il) % infos[id].cnt != infos[id].r) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t os = i;
|
||||
const int64_t od = id;
|
||||
|
||||
for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
|
||||
buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
infos[id].r++;
|
||||
#endif
|
||||
}
|
||||
|
||||
tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
||||
|
||||
ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
|
||||
}
|
||||
|
||||
// copy the KV cache to the host memory and reshuffle the cells to the beginning of the cache
|
||||
// removing any empty segments that may have been left by previous KV cache operations
|
||||
// TODO: optimizations are possible:
|
||||
// - multiple threads
|
||||
// - avoid copying to the host memory when already there
|
||||
// TODO: can we do all this on-device?
|
||||
static void llama_kv_cache_defrag_internal(struct llama_context & lctx) {
|
||||
auto & kv_self = lctx.kv_self;
|
||||
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
const uint32_t n_kv = llama_kv_cache_cell_max(kv_self);
|
||||
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
std::vector<uint8_t> buf_k;
|
||||
std::vector<uint8_t> buf_v;
|
||||
|
||||
// the destination cell in the new KV cache
|
||||
uint32_t id = 0;
|
||||
|
||||
// number of cells moved
|
||||
uint32_t n_moves = 0;
|
||||
|
||||
// determine which KV cells to move where
|
||||
std::vector<uint32_t> ids(n_kv, n_kv);
|
||||
|
||||
for (uint32_t i0 = 0; i0 < n_kv; ++i0) {
|
||||
const auto & cell0 = kv_self.cells[i0];
|
||||
|
||||
if (!cell0.is_empty()) {
|
||||
ids[i0] = id;
|
||||
|
||||
if (i0 != id) {
|
||||
kv_self.cells[id] = cell0;
|
||||
n_moves++;
|
||||
}
|
||||
|
||||
id++;
|
||||
}
|
||||
}
|
||||
|
||||
if (n_moves == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV defrag cell moves: %u\n", n_moves);
|
||||
|
||||
kv_self.head = id;
|
||||
kv_self.used = id;
|
||||
|
||||
// zero the rest of the cells
|
||||
for (uint32_t i = id; i < n_kv; ++i) {
|
||||
kv_self.cells[i] = llama_kv_cell();
|
||||
}
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
|
||||
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_size);
|
||||
|
||||
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
||||
const size_t v_size = ggml_row_size (kv_self.v_l[il]->type, n_embd_v_gqa*kv_size);
|
||||
|
||||
buf_k.resize(k_size);
|
||||
buf_v.resize(v_size);
|
||||
|
||||
ggml_backend_tensor_get(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
|
||||
ggml_backend_tensor_get(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
|
||||
|
||||
// batch move [i, i+nm) to [id, id+nm)
|
||||
// note: cells can move only to a lower index
|
||||
for (uint32_t i = 0; i < n_kv; ++i) {
|
||||
const uint32_t id = ids[i];
|
||||
|
||||
if (i == id || id == n_kv) {
|
||||
continue;
|
||||
}
|
||||
|
||||
uint32_t nm = 1;
|
||||
|
||||
while (i + nm < n_kv && ids[i + nm] == id + nm) {
|
||||
nm++;
|
||||
}
|
||||
|
||||
// move keys
|
||||
{
|
||||
const int64_t os = i*k_size_row;
|
||||
const int64_t od = id*k_size_row;
|
||||
|
||||
memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row);
|
||||
}
|
||||
|
||||
// move values (note: they are transposed)
|
||||
{
|
||||
const int64_t os = i;
|
||||
const int64_t od = id;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el);
|
||||
}
|
||||
}
|
||||
|
||||
i += nm - 1;
|
||||
}
|
||||
|
||||
ggml_backend_tensor_set(kv_self.k_l[il], buf_k.data(), 0, buf_k.size());
|
||||
ggml_backend_tensor_set(kv_self.v_l[il], buf_v.data(), 0, buf_v.size());
|
||||
}
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV defrag time: %.3f ms\n", (t_end - t_start)/1000.0);
|
||||
}
|
||||
|
||||
static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||
// apply K-shift if needed
|
||||
if (lctx.model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE && lctx.kv_self.has_shift) {
|
||||
|
@ -8051,240 +8409,19 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
|||
}
|
||||
}
|
||||
|
||||
// compress the KV cache data if needed:
|
||||
//
|
||||
// - determine which KV cell pairs (i0, i1) to merge:
|
||||
//
|
||||
// abs(cell[i0].pos - cell[i1].pos) <= compress_delta
|
||||
//
|
||||
// - move the KV cache to the Host memory for easier maniiplation
|
||||
// - processing is done layer-by-layer
|
||||
// - convert the KV data to F32
|
||||
// - merge the KV data (different ways to merge)
|
||||
// - convert the KV data back to the original type
|
||||
// - move the KV cache back to the device memory
|
||||
// - update the KV cache metadata
|
||||
//
|
||||
// as a side effect, the new KV cache is defragmented
|
||||
//
|
||||
// compress the KV cache data if needed
|
||||
if (lctx.kv_self.compress_delta >= 0) {
|
||||
auto & kv_self = lctx.kv_self;
|
||||
llama_kv_cache_compress_internal(lctx);
|
||||
|
||||
const auto & hparams = lctx.model.hparams;
|
||||
lctx.kv_self.compress_delta = -1;
|
||||
lctx.kv_self.do_defrag = false;
|
||||
}
|
||||
|
||||
const uint32_t n_layer = hparams.n_layer;
|
||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||
const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
|
||||
const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
|
||||
const uint32_t n_head_kv = hparams.n_head_kv; GGML_UNUSED(n_head_kv);
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
// defragment the KV cache if needed
|
||||
if (lctx.kv_self.do_defrag) {
|
||||
llama_kv_cache_defrag_internal(lctx);
|
||||
|
||||
std::vector<uint8_t> buf_q;
|
||||
|
||||
std::vector<float> buf_src_f32;
|
||||
std::vector<float> buf_dst_f32;
|
||||
|
||||
const int64_t t_start = ggml_time_us();
|
||||
|
||||
struct c_pair { uint32_t i0, i1; };
|
||||
struct c_info { bool merged; uint32_t id, cnt, r; };
|
||||
|
||||
std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
|
||||
|
||||
// the destination cell in the new KV cache
|
||||
uint32_t id = 0;
|
||||
|
||||
// number of pairs merged
|
||||
uint32_t n_merges = 0;
|
||||
|
||||
// determine which KV cells to merge
|
||||
for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
|
||||
const auto & cell0 = kv_self.cells[i0];
|
||||
|
||||
if (!cell0.is_empty() && !infos[i0].merged) {
|
||||
infos[i0] = { true, id, 0, 0 };
|
||||
infos[id].cnt = 1;
|
||||
|
||||
const llama_pos p0 = cell0.pos;
|
||||
|
||||
for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
|
||||
const auto & cell1 = kv_self.cells[i1];
|
||||
|
||||
if (i0 != i1 && cell0.is_same_seq(cell1)) {
|
||||
const llama_pos p1 = cell1.pos;
|
||||
|
||||
if (std::abs(p0 - p1) <= kv_self.compress_delta) {
|
||||
infos[i1] = { true, id, 0, 0 };
|
||||
infos[id].cnt++;
|
||||
n_merges++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (i0 != id) {
|
||||
kv_self.cells[id] = cell0;
|
||||
}
|
||||
|
||||
id++;
|
||||
}
|
||||
}
|
||||
|
||||
kv_self.head = id;
|
||||
kv_self.used = id;
|
||||
|
||||
for (uint32_t i = id; i < kv_size; ++i) {
|
||||
kv_self.cells[i] = llama_kv_cell();
|
||||
}
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
|
||||
|
||||
ggml_type_traits_t tt_k;
|
||||
ggml_type_traits_t tt_v;
|
||||
|
||||
tt_k = ggml_internal_get_type_traits(kv_self.type_k);
|
||||
tt_v = ggml_internal_get_type_traits(kv_self.type_v);
|
||||
|
||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
infos[i].r = 0;
|
||||
}
|
||||
|
||||
// update keys
|
||||
{
|
||||
const int64_t ne = n_embd_k_gqa*kv_size;
|
||||
|
||||
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
|
||||
|
||||
buf_q.resize(k_size);
|
||||
|
||||
buf_src_f32.resize(ne);
|
||||
buf_dst_f32.resize(ne);
|
||||
|
||||
ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
|
||||
|
||||
tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
|
||||
|
||||
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
if (!infos[i].merged) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t id = infos[i].id;
|
||||
|
||||
#if 1
|
||||
// merge using averaging
|
||||
{
|
||||
const float scale = 1.0f/float(infos[id].cnt);
|
||||
|
||||
const int64_t os = i*n_embd_k_gqa;
|
||||
const int64_t od = id*n_embd_k_gqa;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
|
||||
buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
|
||||
}
|
||||
}
|
||||
#else
|
||||
// merge separate heads
|
||||
{
|
||||
for (uint32_t h = 0; h < n_head_kv; ++h) {
|
||||
if ((h + il) % infos[id].cnt != infos[id].r) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k;
|
||||
const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_head_k; ++j) {
|
||||
buf_dst_f32[od + j] = buf_src_f32[os + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
infos[id].r++;
|
||||
#endif
|
||||
}
|
||||
|
||||
tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
||||
|
||||
ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
infos[i].r = 0;
|
||||
}
|
||||
|
||||
// update values (note: they are transposed)
|
||||
{
|
||||
const int64_t ne = n_embd_v_gqa*kv_size;
|
||||
|
||||
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
|
||||
|
||||
buf_q.resize(v_size);
|
||||
|
||||
buf_src_f32.resize(ne);
|
||||
buf_dst_f32.resize(ne);
|
||||
|
||||
ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
|
||||
|
||||
tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
|
||||
|
||||
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
|
||||
|
||||
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||
if (!infos[i].merged) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const uint32_t id = infos[i].id;
|
||||
|
||||
#if 1
|
||||
// merge using averaging
|
||||
{
|
||||
const float scale = 1.0f/float(infos[id].cnt);
|
||||
//printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
|
||||
|
||||
const int64_t os = i;
|
||||
const int64_t od = id;
|
||||
|
||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||
buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
|
||||
}
|
||||
}
|
||||
#else
|
||||
// merge separate heads
|
||||
{
|
||||
for (uint32_t h = 0; h < n_head_kv; ++h) {
|
||||
if ((h + il) % infos[id].cnt != infos[id].r) {
|
||||
continue;
|
||||
}
|
||||
|
||||
const int64_t os = i;
|
||||
const int64_t od = id;
|
||||
|
||||
for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
|
||||
buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
infos[id].r++;
|
||||
#endif
|
||||
}
|
||||
|
||||
tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
||||
|
||||
ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
|
||||
}
|
||||
}
|
||||
|
||||
const int64_t t_end = ggml_time_us();
|
||||
|
||||
LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
|
||||
|
||||
kv_self.compress_delta = -1;
|
||||
lctx.kv_self.do_defrag = false;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -12360,6 +12497,10 @@ void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) {
|
|||
llama_kv_cache_compress(ctx->kv_self, delta);
|
||||
}
|
||||
|
||||
void llama_kv_cache_defrag(struct llama_context * ctx) {
|
||||
llama_kv_cache_defrag(ctx->kv_self);
|
||||
}
|
||||
|
||||
void llama_kv_cache_update(struct llama_context * ctx) {
|
||||
llama_kv_cache_update_internal(*ctx);
|
||||
}
|
||||
|
|
11
llama.h
11
llama.h
|
@ -555,11 +555,20 @@ extern "C" {
|
|||
llama_seq_id seq_id);
|
||||
|
||||
// [EXPERIMENTAL] Compress the data in the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_cache_update()
|
||||
LLAMA_API void llama_kv_cache_compress(
|
||||
struct llama_context * ctx,
|
||||
llama_pos delta);
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts) to the KV data
|
||||
// Defragment the KV cache
|
||||
// This will be applied:
|
||||
// - lazily on next llama_decode()
|
||||
// - explicitly with llama_kv_cache_update()
|
||||
LLAMA_API void llama_kv_cache_defrag(struct llama_context * ctx);
|
||||
|
||||
// Apply the KV cache updates (such as K-shifts, defragmentation, etc.)
|
||||
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
|
||||
|
||||
//
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue