From be714a0fdaa78e35c74ff539474d1ebcf564ab5a Mon Sep 17 00:00:00 2001 From: Jan Boon Date: Tue, 2 Apr 2024 04:17:15 +0800 Subject: [PATCH] check types for stricter restore --- llama.cpp | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 1dbb9a93d..726db8218 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15119,6 +15119,8 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) } for (int il = 0; il < (int)n_layer; ++il) { + // types of keys and values + s_cell_data_size += sizeof(int32_t) * 2; // k_size_row and v_size_el values of layer s_cell_data_size += sizeof(size_t) * 2; @@ -15210,6 +15212,10 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam // Get whole range at a time std::vector tmp_buf; for (int il = 0; il < (int)n_layer; ++il) { + // Write key type + const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + data_ctx.write(&k_type_i, sizeof(k_type_i)); + // Write row size of key const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); data_ctx.write(&k_size_row, sizeof(k_size_row)); @@ -15226,6 +15232,10 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam // For the values, they are transposed, so we also need the element size and get the element ranges from each row const uint32_t kv_size = kv_self.size; for (int il = 0; il < (int)n_layer; ++il) { + // Write value type + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + data_ctx.write(&v_type_i, sizeof(v_type_i)); + // Write element size const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); data_ctx.write(&v_size_el, sizeof(v_size_el)); @@ -15334,6 +15344,17 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, // For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo for (int il = 0; il < (int)n_layer; ++il) { + // Read type of key + int32_t k_type_i_ref; + memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref)); + inp += sizeof(k_type_i_ref); + const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type; + if (k_type_i != k_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return 0; + } + // Read row size of key size_t k_size_row_ref; memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref)); @@ -15354,11 +15375,21 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, // For each layer, read the values for each cell (transposed) for (int il = 0; il < (int)n_layer; ++il) { + // Read type of value + int32_t v_type_i_ref; + memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); + inp += sizeof(v_type_i_ref); + const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; + if (v_type_i != v_type_i_ref) { + llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return 0; + } + // Read element size of value size_t v_size_el_ref; memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); inp += sizeof(v_size_el_ref); - const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); if (v_size_el != v_size_el_ref) { llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);