llama : disallow incompatible states
This commit is contained in:
parent
bab346ba69
commit
0fc5c5eb74
2 changed files with 7 additions and 1 deletions
|
@ -16323,11 +16323,13 @@ static void llama_state_get_data_internal(struct llama_context * ctx, llama_data
|
||||||
const uint32_t kv_size = kv_self.size;
|
const uint32_t kv_size = kv_self.size;
|
||||||
const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
|
const size_t kv_buf_size = kv_self.total_size() / (kv_size ? kv_size : 1) * kv_head;
|
||||||
const uint32_t kv_used = kv_self.used;
|
const uint32_t kv_used = kv_self.used;
|
||||||
|
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
|
||||||
|
|
||||||
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
|
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
|
||||||
data_ctx->write(&kv_head, sizeof(kv_head));
|
data_ctx->write(&kv_head, sizeof(kv_head));
|
||||||
data_ctx->write(&kv_size, sizeof(kv_size));
|
data_ctx->write(&kv_size, sizeof(kv_size));
|
||||||
data_ctx->write(&kv_used, sizeof(kv_used));
|
data_ctx->write(&kv_used, sizeof(kv_used));
|
||||||
|
data_ctx->write(&v_trans, sizeof(v_trans));
|
||||||
|
|
||||||
if (kv_buf_size) {
|
if (kv_buf_size) {
|
||||||
const size_t pre_kv_buf_size = data_ctx->get_size_written();
|
const size_t pre_kv_buf_size = data_ctx->get_size_written();
|
||||||
|
@ -16473,11 +16475,15 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
uint32_t kv_head;
|
uint32_t kv_head;
|
||||||
uint32_t kv_size;
|
uint32_t kv_size;
|
||||||
uint32_t kv_used;
|
uint32_t kv_used;
|
||||||
|
uint32_t v_trans;
|
||||||
|
|
||||||
memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
|
memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
|
||||||
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
|
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
|
||||||
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
|
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
|
||||||
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
|
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
|
||||||
|
memcpy(&v_trans, inp, sizeof(v_trans)); inp += sizeof(v_trans);
|
||||||
|
|
||||||
|
GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition
|
||||||
|
|
||||||
if (kv_self.size != kv_size) {
|
if (kv_self.size != kv_size) {
|
||||||
// the KV cache needs to be big enough to load all the KV cells from the saved state
|
// the KV cache needs to be big enough to load all the KV cells from the saved state
|
||||||
|
|
2
llama.h
2
llama.h
|
@ -40,7 +40,7 @@
|
||||||
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
||||||
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
#define LLAMA_SESSION_VERSION 5
|
#define LLAMA_SESSION_VERSION 6
|
||||||
|
|
||||||
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||||
#define LLAMA_STATE_SEQ_VERSION 1
|
#define LLAMA_STATE_SEQ_VERSION 1
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue