check types for stricter restore

This commit is contained in:
Jan Boon 2024-04-02 04:17:15 +08:00
parent b3f6da3d60
commit be714a0fda

View file

@ -15119,6 +15119,8 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
} }
for (int il = 0; il < (int)n_layer; ++il) { for (int il = 0; il < (int)n_layer; ++il) {
// types of keys and values
s_cell_data_size += sizeof(int32_t) * 2;
// k_size_row and v_size_el values of layer // k_size_row and v_size_el values of layer
s_cell_data_size += sizeof(size_t) * 2; s_cell_data_size += sizeof(size_t) * 2;
@ -15210,6 +15212,10 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
// Get whole range at a time // Get whole range at a time
std::vector<uint8_t> tmp_buf; std::vector<uint8_t> tmp_buf;
for (int il = 0; il < (int)n_layer; ++il) { for (int il = 0; il < (int)n_layer; ++il) {
// Write key type
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
data_ctx.write(&k_type_i, sizeof(k_type_i));
// Write row size of key // Write row size of key
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa); const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
data_ctx.write(&k_size_row, sizeof(k_size_row)); data_ctx.write(&k_size_row, sizeof(k_size_row));
@ -15226,6 +15232,10 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
// For the values, they are transposed, so we also need the element size and get the element ranges from each row // For the values, they are transposed, so we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size; const uint32_t kv_size = kv_self.size;
for (int il = 0; il < (int)n_layer; ++il) { for (int il = 0; il < (int)n_layer; ++il) {
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
data_ctx.write(&v_type_i, sizeof(v_type_i));
// Write element size // Write element size
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
data_ctx.write(&v_size_el, sizeof(v_size_el)); data_ctx.write(&v_size_el, sizeof(v_size_el));
@ -15334,6 +15344,17 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
for (int il = 0; il < (int)n_layer; ++il) { for (int il = 0; il < (int)n_layer; ++il) {
// Read type of key
int32_t k_type_i_ref;
memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
inp += sizeof(k_type_i_ref);
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
if (k_type_i != k_type_i_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
return 0;
}
// Read row size of key // Read row size of key
size_t k_size_row_ref; size_t k_size_row_ref;
memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref)); memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref));
@ -15354,11 +15375,21 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
// For each layer, read the values for each cell (transposed) // For each layer, read the values for each cell (transposed)
for (int il = 0; il < (int)n_layer; ++il) { for (int il = 0; il < (int)n_layer; ++il) {
// Read type of value
int32_t v_type_i_ref;
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
inp += sizeof(v_type_i_ref);
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return 0;
}
// Read element size of value // Read element size of value
size_t v_size_el_ref; size_t v_size_el_ref;
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
inp += sizeof(v_size_el_ref); inp += sizeof(v_size_el_ref);
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (v_size_el != v_size_el_ref) { if (v_size_el != v_size_el_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);