llama : add llama_kv_cache_compress (EXPERIMENTAL)
This commit is contained in:
parent
715a343343
commit
fdfa5bc76b
3 changed files with 215 additions and 7 deletions
|
@ -146,9 +146,10 @@ int main(int argc, char ** argv) {
|
||||||
const int ib = i/n_batch - 1;
|
const int ib = i/n_batch - 1;
|
||||||
const int bd = n_batch_grp*(n_grp - 1);
|
const int bd = n_batch_grp*(n_grp - 1);
|
||||||
|
|
||||||
llama_kv_cache_seq_add(ctx, 0, n_past - n_batch, n_past, ib*bd);
|
llama_kv_cache_seq_add (ctx, 0, n_past - n_batch, n_past, ib*bd);
|
||||||
llama_kv_cache_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
llama_kv_cache_seq_div (ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
|
||||||
llama_kv_cache_update (ctx);
|
llama_kv_cache_compress(ctx, 0);
|
||||||
|
llama_kv_cache_update (ctx);
|
||||||
|
|
||||||
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
n_past = llama_kv_cache_seq_pos_max(ctx, 0) + 1;
|
||||||
}
|
}
|
||||||
|
|
210
llama.cpp
210
llama.cpp
|
@ -1733,6 +1733,12 @@ struct llama_kv_cache {
|
||||||
// computed before each graph build
|
// computed before each graph build
|
||||||
uint32_t n = 0;
|
uint32_t n = 0;
|
||||||
|
|
||||||
|
ggml_type type_k = GGML_TYPE_F16;
|
||||||
|
ggml_type type_v = GGML_TYPE_F16;
|
||||||
|
|
||||||
|
// if non-negative, compress data on next update
|
||||||
|
llama_pos compress_delta = -1;
|
||||||
|
|
||||||
std::vector<llama_kv_cell> cells;
|
std::vector<llama_kv_cell> cells;
|
||||||
|
|
||||||
std::vector<struct ggml_tensor *> k_l; // per layer
|
std::vector<struct ggml_tensor *> k_l; // per layer
|
||||||
|
@ -1968,8 +1974,8 @@ struct llama_context {
|
||||||
static bool llama_kv_cache_init(
|
static bool llama_kv_cache_init(
|
||||||
struct llama_kv_cache & cache,
|
struct llama_kv_cache & cache,
|
||||||
const llama_model & model,
|
const llama_model & model,
|
||||||
ggml_type ktype,
|
ggml_type type_k,
|
||||||
ggml_type vtype,
|
ggml_type type_v,
|
||||||
uint32_t n_ctx,
|
uint32_t n_ctx,
|
||||||
bool offload) {
|
bool offload) {
|
||||||
const struct llama_hparams & hparams = model.hparams;
|
const struct llama_hparams & hparams = model.hparams;
|
||||||
|
@ -1984,6 +1990,9 @@ static bool llama_kv_cache_init(
|
||||||
cache.size = n_ctx;
|
cache.size = n_ctx;
|
||||||
cache.used = 0;
|
cache.used = 0;
|
||||||
|
|
||||||
|
cache.type_k = type_k;
|
||||||
|
cache.type_v = type_v;
|
||||||
|
|
||||||
cache.cells.clear();
|
cache.cells.clear();
|
||||||
cache.cells.resize(n_ctx);
|
cache.cells.resize(n_ctx);
|
||||||
|
|
||||||
|
@ -2024,8 +2033,8 @@ static bool llama_kv_cache_init(
|
||||||
|
|
||||||
for (int i = 0; i < (int) n_layer; i++) {
|
for (int i = 0; i < (int) n_layer; i++) {
|
||||||
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
struct ggml_context * ctx = offload ? ctx_map.at(model.buft_layer[i].buft) : cache.ctxs.front();
|
||||||
ggml_tensor * k = ggml_new_tensor_1d(ctx, ktype, n_embd_k_gqa*n_ctx);
|
ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*n_ctx);
|
||||||
ggml_tensor * v = ggml_new_tensor_1d(ctx, vtype, n_embd_v_gqa*n_ctx);
|
ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*n_ctx);
|
||||||
ggml_format_name(k, "cache_k_l%d", i);
|
ggml_format_name(k, "cache_k_l%d", i);
|
||||||
ggml_format_name(v, "cache_v_l%d", i);
|
ggml_format_name(v, "cache_v_l%d", i);
|
||||||
cache.k_l.push_back(k);
|
cache.k_l.push_back(k);
|
||||||
|
@ -2265,6 +2274,10 @@ static llama_pos llama_kv_cache_seq_pos_max(struct llama_kv_cache & cache, llama
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void llama_kv_cache_compress(struct llama_kv_cache & cache, llama_pos delta) {
|
||||||
|
cache.compress_delta = delta;
|
||||||
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// model loading and saving
|
// model loading and saving
|
||||||
//
|
//
|
||||||
|
@ -8034,6 +8047,191 @@ static void llama_kv_cache_update_internal(struct llama_context & lctx) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// compress the KV cache data if needed:
|
||||||
|
//
|
||||||
|
// - determine which KV cell pairs (i0, i1) to merge:
|
||||||
|
//
|
||||||
|
// abs(cell[i0].pos - cell[i1].pos) <= compress_delta
|
||||||
|
//
|
||||||
|
// - move the KV cache to the Host memory for easier maniiplation
|
||||||
|
// - processing is done layer-by-layer
|
||||||
|
// - convert the KV data to F32
|
||||||
|
// - merge the KV data (different ways to merge)
|
||||||
|
// - convert the KV data back to the original type
|
||||||
|
// - move the KV cache back to the device memory
|
||||||
|
// - update the KV cache metadata
|
||||||
|
//
|
||||||
|
// as a side effect, the new KV cache is defragmented
|
||||||
|
//
|
||||||
|
if (lctx.kv_self.compress_delta >= 0) {
|
||||||
|
auto & kv_self = lctx.kv_self;
|
||||||
|
|
||||||
|
const auto & hparams = lctx.model.hparams;
|
||||||
|
|
||||||
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa();
|
||||||
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa();
|
||||||
|
const uint32_t kv_size = kv_self.size;
|
||||||
|
|
||||||
|
std::vector<uint8_t> buf_q;
|
||||||
|
|
||||||
|
std::vector<float> buf_src_f32;
|
||||||
|
std::vector<float> buf_dst_f32;
|
||||||
|
|
||||||
|
const int64_t t_start = ggml_time_us();
|
||||||
|
|
||||||
|
struct c_pair { uint32_t i0, i1; };
|
||||||
|
struct c_info { bool merged; uint32_t id, cnt;};
|
||||||
|
|
||||||
|
std::vector<c_info> infos(kv_size, { false, 0, 0 });
|
||||||
|
|
||||||
|
// the destination cell in the new KV cache
|
||||||
|
uint32_t id = 0;
|
||||||
|
|
||||||
|
// number of pairs merged
|
||||||
|
uint32_t n_merges = 0;
|
||||||
|
|
||||||
|
// determine which KV cells to merge
|
||||||
|
for (uint32_t i0 = 0; i0 < kv_size; ++i0) {
|
||||||
|
const auto & cell0 = kv_self.cells[i0];
|
||||||
|
|
||||||
|
if (!cell0.is_empty() && !infos[i0].merged) {
|
||||||
|
infos[i0] = { true, id, 0 };
|
||||||
|
infos[id].cnt = 1;
|
||||||
|
|
||||||
|
const llama_pos p0 = cell0.pos;
|
||||||
|
|
||||||
|
for (uint32_t i1 = i0 + 1; i1 < kv_size; ++i1) {
|
||||||
|
const auto & cell1 = kv_self.cells[i1];
|
||||||
|
|
||||||
|
if (i0 != i1 && cell0.is_same_seq(cell1)) {
|
||||||
|
const llama_pos p1 = cell1.pos;
|
||||||
|
|
||||||
|
if (std::abs(p0 - p1) <= kv_self.compress_delta) {
|
||||||
|
infos[i1] = { true, id, 0 };
|
||||||
|
infos[id].cnt++;
|
||||||
|
n_merges++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (i0 != id) {
|
||||||
|
kv_self.cells[id] = cell0;
|
||||||
|
}
|
||||||
|
|
||||||
|
id++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kv_self.head = id;
|
||||||
|
kv_self.used = id;
|
||||||
|
|
||||||
|
for (uint32_t i = id; i < kv_size; ++i) {
|
||||||
|
kv_self.cells[i] = llama_kv_cell();
|
||||||
|
}
|
||||||
|
|
||||||
|
LLAMA_LOG_INFO("(tmp log) KV compress pairs: %u\n", n_merges);
|
||||||
|
|
||||||
|
ggml_type_traits_t tt_k;
|
||||||
|
ggml_type_traits_t tt_v;
|
||||||
|
|
||||||
|
tt_k = ggml_internal_get_type_traits(kv_self.type_k);
|
||||||
|
tt_v = ggml_internal_get_type_traits(kv_self.type_v);
|
||||||
|
|
||||||
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
|
// update keys
|
||||||
|
{
|
||||||
|
const int64_t ne = n_embd_k_gqa*kv_size;
|
||||||
|
|
||||||
|
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, ne);
|
||||||
|
|
||||||
|
buf_q.resize(k_size);
|
||||||
|
|
||||||
|
buf_src_f32.resize(ne);
|
||||||
|
buf_dst_f32.resize(ne);
|
||||||
|
|
||||||
|
ggml_backend_tensor_get(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
|
||||||
|
|
||||||
|
tt_k.to_float(buf_q.data(), buf_src_f32.data(), ne);
|
||||||
|
|
||||||
|
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||||
|
if (!infos[i].merged) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t id = infos[i].id;
|
||||||
|
|
||||||
|
// merge using averaging
|
||||||
|
{
|
||||||
|
const float scale = 1.0f/float(infos[id].cnt);
|
||||||
|
|
||||||
|
const int64_t os = i*n_embd_k_gqa;
|
||||||
|
const int64_t od = id*n_embd_k_gqa;
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < n_embd_k_gqa; ++j) {
|
||||||
|
buf_dst_f32[od + j] += buf_src_f32[os + j]*scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tt_k.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
||||||
|
|
||||||
|
ggml_backend_tensor_set(kv_self.k_l[il], buf_q.data(), 0, buf_q.size());
|
||||||
|
}
|
||||||
|
|
||||||
|
// update values (note: they are transposed)
|
||||||
|
{
|
||||||
|
const int64_t ne = n_embd_v_gqa*kv_size;
|
||||||
|
|
||||||
|
const size_t v_size = ggml_row_size(kv_self.v_l[il]->type, ne);
|
||||||
|
|
||||||
|
buf_q.resize(v_size);
|
||||||
|
|
||||||
|
buf_src_f32.resize(ne);
|
||||||
|
buf_dst_f32.resize(ne);
|
||||||
|
|
||||||
|
ggml_backend_tensor_get(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
|
||||||
|
|
||||||
|
tt_v.to_float(buf_q.data(), buf_src_f32.data(), ne);
|
||||||
|
|
||||||
|
std::fill(buf_dst_f32.begin(), buf_dst_f32.end(), 0);
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < kv_size; ++i) {
|
||||||
|
if (!infos[i].merged) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
const uint32_t id = infos[i].id;
|
||||||
|
|
||||||
|
// merge using averaging
|
||||||
|
{
|
||||||
|
const float scale = 1.0f/float(infos[id].cnt);
|
||||||
|
//printf("i: %d -> id: %d, scale: %f\n", i, id, scale);
|
||||||
|
|
||||||
|
const int64_t os = i;
|
||||||
|
const int64_t od = id;
|
||||||
|
|
||||||
|
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||||
|
buf_dst_f32[od + j*kv_size] += buf_src_f32[os + j*kv_size]*scale;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tt_v.from_float(buf_dst_f32.data(), buf_q.data(), ne);
|
||||||
|
|
||||||
|
ggml_backend_tensor_set(kv_self.v_l[il], buf_q.data(), 0, buf_q.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const int64_t t_end = ggml_time_us();
|
||||||
|
|
||||||
|
LLAMA_LOG_INFO("(tmp log) KV compress time: %.3f ms\n", (t_end - t_start)/1000.0);
|
||||||
|
|
||||||
|
kv_self.compress_delta = -1;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
|
@ -12083,6 +12281,10 @@ llama_pos llama_kv_cache_seq_pos_max(struct llama_context * ctx, llama_seq_id se
|
||||||
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
return llama_kv_cache_seq_pos_max(ctx->kv_self, seq_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void llama_kv_cache_compress(struct llama_context * ctx, llama_pos delta) {
|
||||||
|
llama_kv_cache_compress(ctx->kv_self, delta);
|
||||||
|
}
|
||||||
|
|
||||||
void llama_kv_cache_update(struct llama_context * ctx) {
|
void llama_kv_cache_update(struct llama_context * ctx) {
|
||||||
llama_kv_cache_update_internal(*ctx);
|
llama_kv_cache_update_internal(*ctx);
|
||||||
}
|
}
|
||||||
|
|
5
llama.h
5
llama.h
|
@ -552,6 +552,11 @@ extern "C" {
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
llama_seq_id seq_id);
|
llama_seq_id seq_id);
|
||||||
|
|
||||||
|
// [EXPERIMENTAL] Compress the data in the KV cache
|
||||||
|
LLAMA_API void llama_kv_cache_compress(
|
||||||
|
struct llama_context * ctx,
|
||||||
|
llama_pos delta);
|
||||||
|
|
||||||
// Apply the KV cache updates (such as K-shifts) to the KV data
|
// Apply the KV cache updates (such as K-shifts) to the KV data
|
||||||
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
|
LLAMA_API void llama_kv_cache_update(struct llama_context * ctx);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue