From e9095aca209c968afe09654c428223c896abee65 Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Mon, 25 Mar 2024 23:13:50 -0400 Subject: [PATCH] llama : allow loading state saved with a different ctx size When loading a session file, the context size is now only required to be at least enough to load the KV cells contained in that session file, instead of requiring to use exactly the same context size as when saving. Doing this enables the use-case of extending or shrinking the context size of a saved session. This breaks existing session files because the meaning of kv_buf_size is slightly changed (previously it was the size of the whole KV cache, now it's only the size of the saved part of it). This allows for finer-grained sanity checks when loading in an effort to keep kv_buf_size useful even when the kv_size is changed. --- llama.cpp | 30 ++++++++++++++++++------------ 1 file changed, 18 insertions(+), 12 deletions(-) diff --git a/llama.cpp b/llama.cpp index 30b778874..18116b962 100644 --- a/llama.cpp +++ b/llama.cpp @@ -14714,9 +14714,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s(); - const size_t kv_buf_size = kv_self.total_size(); + // NOTE: kv_size and kv_buf_size are mostly used for sanity checks const uint32_t kv_head = llama_kv_cache_cell_max(kv_self); const uint32_t kv_size = kv_self.size; + const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head; const uint32_t kv_used = kv_self.used; data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); @@ -14725,6 +14726,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat data_ctx->write(&kv_used, sizeof(kv_used)); if (kv_buf_size) { + const size_t pre_kv_buf_size = data_ctx->get_size_written(); std::vector tmp_buf; for (int il = 0; il < (int) n_layer; ++il) { const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head); @@ -14754,6 +14756,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat data_ctx->write(tmp_buf.data(), tmp_buf.size()); } } + GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size); } for (uint32_t i = 0; i < kv_head; ++i) { @@ -14867,8 +14870,18 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used); + if (kv_self.size != kv_size) { + // the KV cache needs to be big enough to load all the KV cells from the saved state + GGML_ASSERT(kv_self.size >= kv_head); + + LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n", + __func__, kv_head, kv_size, kv_self.size); + } + if (kv_buf_size) { - GGML_ASSERT(kv_self.total_size() == kv_buf_size); + const size_t pre_kv_buf_size = inp - src; + + GGML_ASSERT(kv_self.total_size() >= kv_buf_size); for (int il = 0; il < (int) n_layer; ++il) { const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head); @@ -14888,23 +14901,21 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { // v is not contiguous, copy row by row const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head); - const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size); + const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size); for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) { ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size); inp += v_row_size; } } + GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size); } - GGML_ASSERT(kv_self.size == kv_size); + llama_kv_cache_clear(ctx); ctx->kv_self.head = kv_head; - ctx->kv_self.size = kv_size; ctx->kv_self.used = kv_used; - ctx->kv_self.cells.resize(kv_size); - for (uint32_t i = 0; i < kv_head; ++i) { llama_pos pos; size_t seq_id_size; @@ -14921,11 +14932,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { ctx->kv_self.cells[i].seq_id.insert(seq_id); } } - - for (uint32_t i = kv_head; i < kv_size; ++i) { - ctx->kv_self.cells[i].pos = -1; - ctx->kv_self.cells[i].seq_id.clear(); - } } const size_t nread = inp - src;