llama : disallow incompatible states

This commit is contained in:
Georgi Gerganov 2024-04-25 19:53:57 +03:00
parent bab346ba69
commit 0fc5c5eb74
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 7 additions and 1 deletions

View file

@@ -16323,11 +16323,13 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
const uint32_t kv_size = kv_self.size;
const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
const uint32_t kv_used = kv_self.used;
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
data_ctx->write(&kv_head, sizeof(kv_head));
data_ctx->write(&kv_size, sizeof(kv_size));
data_ctx->write(&kv_used, sizeof(kv_used));
data_ctx->write(&v_trans, sizeof(v_trans));
if (kv_buf_size) {
const size_t pre_kv_buf_size = data_ctx->get_size_written();
@@ -16473,11 +16475,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
uint32_t kv_head;
uint32_t kv_size;
uint32_t kv_used;
uint32_t v_trans;
memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans);
GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition
if (kv_self.size != kv_size) {
// the KV cache needs to be big enough to load all the KV cells from the saved state

View file

@@ -40,7 +40,7 @@
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
#define LLAMA_SESSION_VERSION 5
#define LLAMA_SESSION_VERSION 6
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
#define LLAMA_STATE_SEQ_VERSION 1