llama : add alternative KV cache merging (EXPERIMENTAL)
This commit is contained in:
parent
0d6f8734b3
commit
9ec749df59
1 changed files with 59 additions and 8 deletions
67
llama.cpp
67
llama.cpp
|
@ -8072,10 +8072,13 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
|
|
||||||
const auto & hparams = lctx.model.hparams;
|
const auto & hparams = lctx.model.hparams;
|
||||||
|
|
||||||
const uint32_t n_layer = hparams.n_layer;
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
const uint32_t kv_size = kv_self.size;
|
const uint32_t n_embd_head_k = hparams.n_embd_head_k; GGML_UNUSED(n_embd_head_k);
|
||||||
|
const uint32_t n_embd_head_v = hparams.n_embd_head_v; GGML_UNUSED(n_embd_head_v);
|
||||||
|
const uint32_t n_head_kv = hparams.n_head_kv; GGML_UNUSED(n_head_kv);
|
||||||
|
const uint32_t kv_size = kv_self.size;
|
||||||
|
|
||||||
std::vector<uint8_t> buf_q;
|
std::vector<uint8_t> buf_q;
|
||||||
|
|
||||||
|
@ -8085,9 +8088,9 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
const int64_t t_start = ggml_time_us();
|
const int64_t t_start = ggml_time_us();
|
||||||
|
|
||||||
struct c_pair { uint32_t i0, i1; };
|
struct c_pair { uint32_t i0, i1; };
|
||||||
struct c_info { bool merged; uint32_t id, cnt;};
|
struct c_info { bool merged; uint32_t id, cnt, r; };
|
||||||
|
|
||||||
std::vector<c_info> infos(kv_size, { false, 0, 0 });
|
std::vector<c_info> infos(kv_size, { false, 0, 0, 0 });
|
||||||
|
|
||||||
// the destination cell in the new KV cache
|
// the destination cell in the new KV cache
|
||||||
uint32_t id = 0;
|
uint32_t id = 0;
|
||||||
|
@ -8100,7 +8103,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
const auto & cell0 = kv_self.cells[i0];
|
const auto & cell0 = kv_self.cells[i0];
|
||||||
|
|
||||||
if (!cell0.is_empty() && !infos[i0].merged) {
|
if (!cell0.is_empty() && !infos[i0].merged) {
|
||||||
infos[i0] = { true, id, 0 };
|
infos[i0] = { true, id, 0, 0 };
|
||||||
infos[id].cnt = 1;
|
infos[id].cnt = 1;
|
||||||
|
|
||||||
const llama_pos p0 = cell0.pos;
|
const llama_pos p0 = cell0.pos;
|
||||||
|
@ -8112,7 +8115,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
const llama_pos p1 = cell1.pos;
|
const llama_pos p1 = cell1.pos;
|
||||||
|
|
||||||
if (std::abs(p0 - p1) <= kv_self.compress_delta) {
|
if (std::abs(p0 - p1) <= kv_self.compress_delta) {
|
||||||
infos[i1] = { true, id, 0 };
|
infos[i1] = { true, id, 0, 0 };
|
||||||
infos[id].cnt++;
|
infos[id].cnt++;
|
||||||
n_merges++;
|
n_merges++;
|
||||||
}
|
}
|
||||||
|
@ -8143,6 +8146,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
tt_v = ggml_internal_get_type_traits(kv_self.type_v);
|
tt_v = ggml_internal_get_type_traits(kv_self.type_v);
|
||||||
|
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
|
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||||
|
infos[i].r = 0;
|
||||||
|
}
|
||||||
|
|
||||||
// update keys
|
// update keys
|
||||||
{
|
{
|
||||||
const int64_t ne = n_embd_k_gqa*kv_size;
|
const int64_t ne = n_embd_k_gqa*kv_size;
|
||||||
|
@ -8167,6 +8174,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
|
|
||||||
const uint32_t id = infos[i].id;
|
const uint32_t id = infos[i].id;
|
||||||
|
|
||||||
|
#if 1
|
||||||
// merge using averaging
|
// merge using averaging
|
||||||
{
|
{
|
||||||
const float scale = 1.0f/float(infos[id].cnt);
|
const float scale = 1.0f/float(infos[id].cnt);
|
||||||
|
@ -8178,6 +8186,25 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
|
buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
// merge separate heads
|
||||||
|
{
|
||||||
|
for (uint32_t h = 0; h < n_head_kv; ++h) {
|
||||||
|
if ((h + il) % infos[id].cnt != infos[id].r) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t os = i*n_embd_k_gqa + h*n_embd_head_k;
|
||||||
|
const int64_t od = id*n_embd_k_gqa + h*n_embd_head_k;
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < n_embd_head_k; ++j) {
|
||||||
|
buf_dst_f32[od + j] = buf_src_f32[os + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
infos[id].r++;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
||||||
|
@ -8185,6 +8212,10 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
|
ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||||
|
infos[i].r = 0;
|
||||||
|
}
|
||||||
|
|
||||||
// update values (note: they are transposed)
|
// update values (note: they are transposed)
|
||||||
{
|
{
|
||||||
const int64_t ne = n_embd_v_gqa*kv_size;
|
const int64_t ne = n_embd_v_gqa*kv_size;
|
||||||
|
@ -8209,6 +8240,7 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
|
|
||||||
const uint32_t id = infos[i].id;
|
const uint32_t id = infos[i].id;
|
||||||
|
|
||||||
|
#if 1
|
||||||
// merge using averaging
|
// merge using averaging
|
||||||
{
|
{
|
||||||
const float scale = 1.0f/float(infos[id].cnt);
|
const float scale = 1.0f/float(infos[id].cnt);
|
||||||
|
@ -8221,6 +8253,25 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
|
buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
// merge separate heads
|
||||||
|
{
|
||||||
|
for (uint32_t h = 0; h < n_head_kv; ++h) {
|
||||||
|
if ((h + il) % infos[id].cnt != infos[id].r) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t os = i;
|
||||||
|
const int64_t od = id;
|
||||||
|
|
||||||
|
for (uint32_t j = h*n_embd_head_v; j < (h + 1)*n_embd_head_v; ++j) {
|
||||||
|
buf_dst_f32[od + j*kv_size] = buf_src_f32[os + j*kv_size];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
infos[id].r++;
|
||||||
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue