llama : llama_kv_cache_clear zeroes data + fix save-load seq

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-04-25 19:37:27 +03:00
parent ac1c6d91de
commit c225609f10
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 103 additions and 46 deletions

147
llama.cpp
View file

@ -2566,6 +2566,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
} }
cache.head = 0; cache.head = 0;
cache.used = 0; cache.used = 0;
for (auto & buf : cache.bufs) {
ggml_backend_buffer_clear(buf, 0);
}
} }
static bool llama_kv_cache_seq_rm( static bool llama_kv_cache_seq_rm(
@ -16483,6 +16487,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
__func__, kv_head, kv_size, kv_self.size); __func__, kv_head, kv_size, kv_self.size);
} }
llama_kv_cache_clear(ctx);
if (kv_buf_size) { if (kv_buf_size) {
const size_t pre_kv_buf_size = inp - src; const size_t pre_kv_buf_size = inp - src;
@ -16516,8 +16522,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size); GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
} }
llama_kv_cache_clear(ctx);
ctx->kv_self.head = kv_head; ctx->kv_self.head = kv_head;
ctx->kv_self.used = kv_used; ctx->kv_self.used = kv_used;
@ -16777,28 +16781,48 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
} }
} }
// For the values, they are transposed, so we also need the element size and get the element ranges from each row if (!kv_self.v_trans) {
const uint32_t kv_size = kv_self.size; for (int il = 0; il < (int)n_layer; ++il) {
for (int il = 0; il < (int)n_layer; ++il) { // Write value type
// Write value type const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; data_ctx.write(&v_type_i, sizeof(v_type_i));
data_ctx.write(&v_type_i, sizeof(v_type_i));
// Write element size // Write row size of value
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
data_ctx.write(&v_size_el, sizeof(v_size_el)); data_ctx.write(&v_size_row, sizeof(v_size_row));
// For each row, we get the element values of each cell // Read each range of cells of v_size_row length each into tmp_buf and write out
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
// Read each range of cells of v_size_el length each into tmp_buf and write out
for (const auto & range : cell_ranges) { for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first; const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * kv_size) * v_size_el; tmp_buf.resize(range_size * v_size_row);
tmp_buf.resize(range_size * v_size_el); ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
data_ctx.write(tmp_buf.data(), tmp_buf.size()); data_ctx.write(tmp_buf.data(), tmp_buf.size());
} }
} }
} else {
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (int il = 0; il < (int)n_layer; ++il) {
// Write value type
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
data_ctx.write(&v_type_i, sizeof(v_type_i));
// Write element size
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
data_ctx.write(&v_size_el, sizeof(v_size_el));
// For each row, we get the element values of each cell
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
// Read each range of cells of v_size_el length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
tmp_buf.resize(range_size * v_size_el);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
data_ctx.write(tmp_buf.data(), tmp_buf.size());
}
}
}
} }
return data_ctx.get_size_written(); return data_ctx.get_size_written();
@ -16923,41 +16947,74 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
} }
} }
// For each layer, read the values for each cell (transposed) if (!kv_self.v_trans) {
for (int il = 0; il < (int)n_layer; ++il) { for (int il = 0; il < (int)n_layer; ++il) {
// Read type of value // Read type of value
int32_t v_type_i_ref; int32_t v_type_i_ref;
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref)); memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
inp += sizeof(v_type_i_ref); inp += sizeof(v_type_i_ref);
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type; const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) { if (v_type_i != v_type_i_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return 0; return 0;
} }
// Read element size of value // Read row size of value
size_t v_size_el_ref; size_t v_size_row_ref;
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref)); memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
inp += sizeof(v_size_el_ref); inp += sizeof(v_size_row_ref);
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type); const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
if (v_size_el != v_size_el_ref) { if (v_size_row != v_size_row_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1); llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il); LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
return 0; return 0;
} }
if (cell_count) { if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range // Read and set the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
const size_t dst_offset = (kv_head + j * kv_size) * v_size_el; inp += cell_count * v_size_row;
ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el); }
inp += cell_count * v_size_el; }
} else {
// For each layer, read the values for each cell (transposed)
for (int il = 0; il < (int)n_layer; ++il) {
// Read type of value
int32_t v_type_i_ref;
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
inp += sizeof(v_type_i_ref);
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
if (v_type_i != v_type_i_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
return 0;
}
// Read element size of value
size_t v_size_el_ref;
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
inp += sizeof(v_size_el_ref);
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
if (v_size_el != v_size_el_ref) {
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
return 0;
}
if (cell_count) {
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
inp += cell_count * v_size_el;
}
} }
} }
} }
const size_t nread = inp - src; const size_t nread = inp - src;
return nread; return nread;
} }

View file

@ -526,7 +526,7 @@ extern "C" {
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them) // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx); LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache // Clear the KV cache - the cell info is erased and the KV data is zeroed
LLAMA_API void llama_kv_cache_clear( LLAMA_API void llama_kv_cache_clear(
struct llama_context * ctx); struct llama_context * ctx);