llama : allow loading state saved with a different ctx size
When loading a session file, the context size is now only required to be at least enough to load the KV cells contained in that session file, instead of requiring to use exactly the same context size as when saving. Doing this enables the use-case of extending or shrinking the context size of a saved session. This breaks existing session files because the meaning of kv_buf_size is slightly changed (previously it was the size of the whole KV cache, now it's only the size of the saved part of it). This allows for finer-grained sanity checks when loading in an effort to keep kv_buf_size useful even when the kv_size is changed.
This commit is contained in:
parent
ffa9abd9c3
commit
e9095aca20
1 changed files with 18 additions and 12 deletions
30
llama.cpp
30
llama.cpp
|
@ -14714,9 +14714,10 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
|||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
|
||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
|
||||
|
||||
const size_t kv_buf_size = kv_self.total_size();
|
||||
// NOTE: kv_size and kv_buf_size are mostly used for sanity checks
|
||||
const uint32_t kv_head = llama_kv_cache_cell_max(kv_self);
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
|
||||
const uint32_t kv_used = kv_self.used;
|
||||
|
||||
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
|
||||
|
@ -14725,6 +14726,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
|||
data_ctx->write(&kv_used, sizeof(kv_used));
|
||||
|
||||
if (kv_buf_size) {
|
||||
const size_t pre_kv_buf_size = data_ctx->get_size_written();
|
||||
std::vector<uint8_t> tmp_buf;
|
||||
for (int il = 0; il < (int) n_layer; ++il) {
|
||||
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
|
||||
|
@ -14754,6 +14756,7 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
|||
data_ctx->write(tmp_buf.data(), tmp_buf.size());
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(kv_buf_size == data_ctx->get_size_written() - pre_kv_buf_size);
|
||||
}
|
||||
|
||||
for (uint32_t i = 0; i < kv_head; ++i) {
|
||||
|
@ -14867,8 +14870,18 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
|||
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
|
||||
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
|
||||
|
||||
if (kv_self.size != kv_size) {
|
||||
// the KV cache needs to be big enough to load all the KV cells from the saved state
|
||||
GGML_ASSERT(kv_self.size >= kv_head);
|
||||
|
||||
LLAMA_LOG_INFO("%s: state contains %d KV cells, was saved with kv_size=%d, but is loaded with kv_size=%d (fine, but different)\n",
|
||||
__func__, kv_head, kv_size, kv_self.size);
|
||||
}
|
||||
|
||||
if (kv_buf_size) {
|
||||
GGML_ASSERT(kv_self.total_size() == kv_buf_size);
|
||||
const size_t pre_kv_buf_size = inp - src;
|
||||
|
||||
GGML_ASSERT(kv_self.total_size() >= kv_buf_size);
|
||||
|
||||
for (int il = 0; il < (int) n_layer; ++il) {
|
||||
const size_t k_size = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa*kv_head);
|
||||
|
@ -14888,23 +14901,21 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
|||
|
||||
// v is not contiguous, copy row by row
|
||||
const size_t v_row_size = ggml_row_size(kv_self.v_l[il]->type, kv_head);
|
||||
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_size);
|
||||
const size_t v_row_stride = ggml_row_size(kv_self.v_l[il]->type, kv_self.size);
|
||||
|
||||
for (int ir = 0; ir < (int) n_embd_v_gqa; ++ir) {
|
||||
ggml_backend_tensor_set(kv_self.v_l[il], inp, ir*v_row_stride, v_row_size);
|
||||
inp += v_row_size;
|
||||
}
|
||||
}
|
||||
GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
|
||||
}
|
||||
|
||||
GGML_ASSERT(kv_self.size == kv_size);
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
ctx->kv_self.head = kv_head;
|
||||
ctx->kv_self.size = kv_size;
|
||||
ctx->kv_self.used = kv_used;
|
||||
|
||||
ctx->kv_self.cells.resize(kv_size);
|
||||
|
||||
for (uint32_t i = 0; i < kv_head; ++i) {
|
||||
llama_pos pos;
|
||||
size_t seq_id_size;
|
||||
|
@ -14921,11 +14932,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
|||
ctx->kv_self.cells[i].seq_id.insert(seq_id);
|
||||
}
|
||||
}
|
||||
|
||||
for (uint32_t i = kv_head; i < kv_size; ++i) {
|
||||
ctx->kv_self.cells[i].pos = -1;
|
||||
ctx->kv_self.cells[i].seq_id.clear();
|
||||
}
|
||||
}
|
||||
|
||||
const size_t nread = inp - src;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue