llama : llama_kv_cache_clear zeroes data + fix save-load seq

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-25 19:37:27 +03:00
parent ac1c6d91de
commit c225609f10
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 103 additions and 46 deletions

147
llama.cpp
View file

@ -2566,6 +2566,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
}
cache.head = 0;
cache.used = 0;
for (auto & buf : cache.bufs) {
ggml_backend_buffer_clear(buf, 0);
}
}
static bool llama_kv_cache_seq_rm(
@ -16483,6 +16487,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
__func__, kv_head, kv_size, kv_self.size);
}
llama_kv_cache_clear(ctx);
if (kv_buf_size) {
const size_t pre_kv_buf_size = inp - src;
@ -16516,8 +16522,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
}
llama_kv_cache_clear(ctx);
ctx->kv_self.head = kv_head;
ctx->kv_self.used = kv_used;
@ -16777,28 +16781,48 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
}
}
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (int il = 0; il < (int)n_layer; ++il) {
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
data_ctx.write(&v_type_i, sizeof(v_type_i));
if (!kv_self.v_trans) {
for (int il = 0; il < (int)n_layer; ++il) {
// Write key type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
data_ctx.write(&v_type_i, sizeof(v_type_i));
// Write element size
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
data_ctx.write(&v_size_el, sizeof(v_size_el));
// Write row size of key
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
data_ctx.write(&v_size_row, sizeof(v_size_row));
// For each row, we get the element values of each cell
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
// Read each range of cells of v_size_el length each into tmp_buf and write out
// Read each range of cells of v_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
tmp_buf.resize(range_size * v_size_el);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
tmp_buf.resize(range_size * v_size_row);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
data_ctx.write(tmp_buf.data(), tmp_buf.size());
}
}
} else {
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (int il = 0; il < (int)n_layer; ++il) {
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
data_ctx.write(&v_type_i, sizeof(v_type_i));
// Write element size
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
data_ctx.write(&v_size_el, sizeof(v_size_el));
// For each row, we get the element values of each cell
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
// Read each range of cells of v_size_el length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
tmp_buf.resize(range_size * v_size_el);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
data_ctx.write(tmp_buf.data(), tmp_buf.size());
}
}
}
}
return data_ctx.get_size_written();
@ -16923,41 +16947,74 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
}
}
// For each layer, read the values for each cell (transposed)
for (int il = 0; il < (int)n_layer; ++il) {
// Read type of value
int32_t v_type_i_ref;
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
inp += sizeof(v_type_i_ref);
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return 0;
}
if (!kv_self.v_trans) {
for (int il = 0; il < (int)n_layer; ++il) {
// Read type of key
int32_t v_type_i_ref;
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
inp += sizeof(v_type_i_ref);
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return 0;
}
// Read element size of value
size_t v_size_el_ref;
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
inp += sizeof(v_size_el_ref);
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (v_size_el != v_size_el_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
return 0;
}
// Read row size of key
size_t v_size_row_ref;
memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
inp += sizeof(v_size_row_ref);
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
if (v_size_row != v_size_row_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
return 0;
}
if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
inp += cell_count * v_size_el;
if (cell_count) {
// Read and set the keys for the whole cell range
ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
inp += cell_count * v_size_row;
}
}
} else {
// For each layer, read the values for each cell (transposed)
for (int il = 0; il < (int)n_layer; ++il) {
// Read type of value
int32_t v_type_i_ref;
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
inp += sizeof(v_type_i_ref);
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return 0;
}
// Read element size of value
size_t v_size_el_ref;
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
inp += sizeof(v_size_el_ref);
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (v_size_el != v_size_el_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
return 0;
}
if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
inp += cell_count * v_size_el;
}
}
}
}
const size_t nread = inp - src;
return nread;
}

View file

@ -526,7 +526,7 @@ extern "C" {
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache
// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_cache_clear(
struct llama_context * ctx);