llama : allow loading state saved with a different ctx size

When loading a session file, the context size is now only required to be
at least large enough to hold the KV cells contained in that session file,
instead of requiring exactly the same context size as was used when saving.

Doing this enables the use-case of extending or shrinking the context size
of a saved session.

This breaks existing session files because the meaning of kv_buf_size
has changed slightly (previously it was the size of the whole KV cache,
now it's only the size of the saved part of it). This allows for
finer-grained sanity checks when loading, in an effort to keep kv_buf_size
useful even when kv_size is changed.
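
For context, a minimal sketch of the use case this enables, using the public
session API from llama.h of this era (llama_save_session_file /
llama_load_session_file); error handling is omitted, and `model`, `tokens`,
and the file path are placeholders:

    // Save a session from a small context, then reload it into a larger one.
    llama_context_params cparams = llama_context_default_params();
    cparams.n_ctx = 2048;
    llama_context * ctx_small = llama_new_context_with_model(model, cparams);
    // ... evaluate some tokens, keeping them in `tokens` ...
    llama_save_session_file(ctx_small, "session.bin", tokens.data(), tokens.size());
    llama_free(ctx_small);

    // Before this change, loading required n_ctx == 2048 exactly.
    // Now any context whose KV cache can hold the saved cells works:
    cparams.n_ctx = 4096;
    llama_context * ctx_big = llama_new_context_with_model(model, cparams);
    std::vector<llama_token> loaded(cparams.n_ctx);
    size_t n_loaded = 0;
    llama_load_session_file(ctx_big, "session.bin", loaded.data(), loaded.size(), &n_loaded);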
Author: Francis Couture-Harpin
Date:   2024-03-25 23:13:50 -04:00
Parent: ffa9abd9c3
Commit: e9095aca20

@@ -14714,9 +14714,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
         const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
         const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();

-        const size_t   kv_buf_size = kv_self.total_size();
+        // NOTE: kv_size and kv_buf_size are mostly used for sanity checks
         const uint32_t kv_head     = llama_kv_cache_cell_max(kv_self);
         const uint32_t kv_size     = kv_self.size;
+        const size_t   kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
         const uint32_t kv_used     = kv_self.used;

         data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
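
The new kv_buf_size is a per-cell proportion of the cache rather than its full
byte size: total_size() divided by the number of cells gives bytes per cell,
multiplied by the kv_head cells actually saved. A worked example with made-up
numbers (not from the commit):

    // Hypothetical sizes, for illustration only:
    const uint32_t kv_size     = 4096;               // cells in the cache
    const uint32_t kv_head     = 1024;               // cells in use at save time
    const size_t   total_size  = 512u*1024*1024;     // kv_self.total_size(), 512 MiB
    // bytes per cell, times the number of saved cells:
    const size_t   kv_buf_size = total_size / (kv_size ? kv_size : 1) * kv_head;
    // -> 128 MiB: only the saved prefix of the cache is accounted for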
@@ -14725,6 +14726,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
         data_ctx->write(&kv_used,     sizeof(kv_used));

         if (kv_buf_size) {
+            const size_t pre_kv_buf_size = data_ctx->get_size_written();
             std::vector<uint8_t> tmp_buf;
             for (int il = 0; il < (int) n_layer; ++il) {
                 const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
@@ -14754,6 +14756,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
                     data_ctx->write(tmp_buf.data(), tmp_buf.size());
                 }
             }
+            GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size);
         }

         for (uint32_t i = 0; i < kv_head; ++i) {
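
The pre_kv_buf_size bookkeeping above is a generic write-accounting pattern:
snapshot the writer's byte count before the payload, then assert that exactly
kv_buf_size bytes were produced. A standalone sketch with a hypothetical
Writer type standing in for llama_data_context:

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Hypothetical byte-counting writer (not the real llama_data_context).
    struct Writer {
        size_t written = 0;
        void write(const void * /*data*/, size_t size) { written += size; }
        size_t get_size_written() const { return written; }
    };

    void write_payload(Writer & w, const std::vector<unsigned char> & payload, size_t declared_size) {
        const size_t pre = w.get_size_written();
        w.write(payload.data(), payload.size());
        // same idea as the new GGML_ASSERT: the payload written must match
        // the size that was declared in the header
        assert(declared_size == w.get_size_written() - pre);
    }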
@@ -14867,8 +14870,18 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
         memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
         memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);

+        if (kv_self.size != kv_size) {
+            // the KV cache needs to be big enough to load all the KV cells from the saved state
+            GGML_ASSERT(kv_self.size >= kv_head);
+
+            LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n",
+                __func__, kv_head, kv_size, kv_self.size);
+        }
+
         if (kv_buf_size) {
-            GGML_ASSERT(kv_self.total_size() == kv_buf_size);
+            const size_t pre_kv_buf_size = inp - src;
+
+            GGML_ASSERT(kv_self.total_size() >= kv_buf_size);

             for (int il = 0; il < (int) n_layer; ++il) {
                 const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
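
On the load side, the strict equality check is gone; what remains boils down
to two inequalities, sketched here as a free function (not a real llama.cpp
API) with names taken from the diff:

    #include <cstddef>
    #include <cstdint>

    // The destination cache must have a cell slot for every saved cell, and
    // its total byte size must be at least the size of the saved KV payload.
    bool kv_state_fits(uint32_t dst_kv_size, size_t dst_total_size,
                       uint32_t saved_kv_head, size_t saved_kv_buf_size) {
        return dst_kv_size >= saved_kv_head && dst_total_size >= saved_kv_buf_size;
    }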
@@ -14888,23 +14901,21 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
                 // v is not contiguous, copy row by row
                 const size_t v_row_size   = ggml_row_size(kv_self.v_l[il]->type, kv_head);
-                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
+                const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size);

                 for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
                     ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
                     inp += v_row_size;
                 }
             }
+
+            GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
         }

-        GGML_ASSERT(kv_self.size == kv_size);
+        llama_kv_cache_clear(ctx);

         ctx->kv_self.head = kv_head;
-        ctx->kv_self.size = kv_size;
         ctx->kv_self.used = kv_used;

-        ctx->kv_self.cells.resize(kv_size);
-
         for (uint32_t i = 0; i < kv_head; ++i) {
             llama_pos pos;
             size_t    seq_id_size;
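
The v_row_stride change is the heart of the cross-size copy: V tensors are
laid out with one row per embedding dimension and one column per KV cell, so
the distance between rows depends on the destination cache's own size
(kv_self.size), while the bytes copied per row depend only on the saved
kv_head. A plain-memory illustration of the same loop, with hypothetical
sizes and float cells for simplicity:

    #include <cstddef>
    #include <cstring>

    // Copy n_saved values into each of n_rows rows of a destination whose
    // rows are dst_stride values apart; source rows are packed back-to-back.
    // dst_stride plays the role of kv_self.size, n_saved the role of kv_head.
    void copy_v_rows(float * dst, const float * src,
                     int n_rows, int n_saved, int dst_stride) {
        for (int ir = 0; ir < n_rows; ++ir) {
            std::memcpy(dst + (size_t) ir*dst_stride,
                        src + (size_t) ir*n_saved,
                        (size_t) n_saved*sizeof(float));
        }
    }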
@@ -14921,11 +14932,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
                 ctx->kv_self.cells[i].seq_id.insert(seq_id);
             }
         }
-
-        for (uint32_t i = kv_head; i < kv_size; ++i) {
-            ctx->kv_self.cells[i].pos = -1;
-            ctx->kv_self.cells[i].seq_id.clear();
-        }
     }

     const size_t nread = inp - src;
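
The removed tail-clearing loop is subsumed by llama_kv_cache_clear: the whole
cache is reset before restoring, so cells in [kv_head, kv_self.size) are
already empty and only the restored prefix needs writing. Schematically, with
a simplified cell type (not the real llama_kv_cell):

    #include <cstdint>
    #include <set>
    #include <vector>

    // Simplified stand-in for llama.cpp's KV cell; pos == -1 marks it empty.
    struct cell { int32_t pos = -1; std::set<int32_t> seq_id; };

    // New load ordering: clear everything first, then restore the prefix.
    // Cells in [kv_head, cells.size()) stay empty with no trailing loop.
    void restore_cells(std::vector<cell> & cells, const std::vector<cell> & saved, uint32_t kv_head) {
        for (auto & c : cells) { c.pos = -1; c.seq_id.clear(); }
        for (uint32_t i = 0; i < kv_head; ++i) { cells[i] = saved[i]; }
    }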