check types for stricter restore
This commit is contained in:
parent
b3f6da3d60
commit
be714a0fda
1 changed files with 32 additions and 1 deletions
33
llama.cpp
33
llama.cpp
|
@ -15119,6 +15119,8 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
|
|||
}
|
||||
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
// types of keys and values
|
||||
s_cell_data_size += sizeof(int32_t) * 2;
|
||||
// k_size_row and v_size_el values of layer
|
||||
s_cell_data_size += sizeof(size_t) * 2;
|
||||
|
||||
|
@ -15210,6 +15212,10 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
|||
// Get whole range at a time
|
||||
std::vector<uint8_t> tmp_buf;
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
// Write key type
|
||||
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
||||
data_ctx.write(&k_type_i, sizeof(k_type_i));
|
||||
|
||||
// Write row size of key
|
||||
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
|
||||
data_ctx.write(&k_size_row, sizeof(k_size_row));
|
||||
|
@ -15226,6 +15232,10 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
|||
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
|
||||
const uint32_t kv_size = kv_self.size;
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
// Write value type
|
||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
||||
|
||||
// Write element size
|
||||
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
||||
data_ctx.write(&v_size_el, sizeof(v_size_el));
|
||||
|
@ -15334,6 +15344,17 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
|||
|
||||
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
// Read type of key
|
||||
int32_t k_type_i_ref;
|
||||
memcpy(&k_type_i_ref, inp, sizeof(k_type_i_ref));
|
||||
inp += sizeof(k_type_i_ref);
|
||||
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
||||
if (k_type_i != k_type_i_ref) {
|
||||
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Read row size of key
|
||||
size_t k_size_row_ref;
|
||||
memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref));
|
||||
|
@ -15354,11 +15375,21 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
|||
|
||||
// For each layer, read the values for each cell (transposed)
|
||||
for (int il = 0; il < (int)n_layer; ++il) {
|
||||
// Read type of value
|
||||
int32_t v_type_i_ref;
|
||||
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
||||
inp += sizeof(v_type_i_ref);
|
||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||
if (v_type_i != v_type_i_ref) {
|
||||
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Read element size of value
|
||||
size_t v_size_el_ref;
|
||||
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
|
||||
inp += sizeof(v_size_el_ref);
|
||||
|
||||
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
||||
if (v_size_el != v_size_el_ref) {
|
||||
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue