llama : llama_kv_cache_clear zeroes data + fix save-load seq
ggml-ci
This commit is contained in:
parent
ac1c6d91de
commit
c225609f10
2 changed files with 103 additions and 46 deletions
147
llama.cpp
147
llama.cpp
|
@ -2566,6 +2566,10 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
|
||||||
}
|
}
|
||||||
cache.head = 0;
|
cache.head = 0;
|
||||||
cache.used = 0;
|
cache.used = 0;
|
||||||
|
|
||||||
|
for (auto & buf : cache.bufs) {
|
||||||
|
ggml_backend_buffer_clear(buf, 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool llama_kv_cache_seq_rm(
|
static bool llama_kv_cache_seq_rm(
|
||||||
|
@ -16483,6 +16487,8 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
__func__, kv_head, kv_size, kv_self.size);
|
__func__, kv_head, kv_size, kv_self.size);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
llama_kv_cache_clear(ctx);
|
||||||
|
|
||||||
if (kv_buf_size) {
|
if (kv_buf_size) {
|
||||||
const size_t pre_kv_buf_size = inp - src;
|
const size_t pre_kv_buf_size = inp - src;
|
||||||
|
|
||||||
|
@ -16516,8 +16522,6 @@ size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
|
GGML_ASSERT(kv_buf_size == inp - src - pre_kv_buf_size);
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_kv_cache_clear(ctx);
|
|
||||||
|
|
||||||
ctx->kv_self.head = kv_head;
|
ctx->kv_self.head = kv_head;
|
||||||
ctx->kv_self.used = kv_used;
|
ctx->kv_self.used = kv_used;
|
||||||
|
|
||||||
|
@ -16777,28 +16781,48 @@ static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llam
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
|
if (!kv_self.v_trans) {
|
||||||
const uint32_t kv_size = kv_self.size;
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
for (int il = 0; il < (int)n_layer; ++il) {
|
// Write key type
|
||||||
// Write value type
|
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
||||||
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
|
||||||
|
|
||||||
// Write element size
|
// Write row size of key
|
||||||
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
|
||||||
data_ctx.write(&v_size_el, sizeof(v_size_el));
|
data_ctx.write(&v_size_row, sizeof(v_size_row));
|
||||||
|
|
||||||
// For each row, we get the element values of each cell
|
// Read each range of cells of v_size length each into tmp_buf and write out
|
||||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
|
||||||
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
|
||||||
for (const auto & range : cell_ranges) {
|
for (const auto & range : cell_ranges) {
|
||||||
const size_t range_size = range.second - range.first;
|
const size_t range_size = range.second - range.first;
|
||||||
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
tmp_buf.resize(range_size * v_size_row);
|
||||||
tmp_buf.resize(range_size * v_size_el);
|
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), range.first * v_size_row, range_size * v_size_row);
|
||||||
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
|
|
||||||
data_ctx.write(tmp_buf.data(), tmp_buf.size());
|
data_ctx.write(tmp_buf.data(), tmp_buf.size());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
|
||||||
|
const uint32_t kv_size = kv_self.size;
|
||||||
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
// Write value type
|
||||||
|
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||||
|
data_ctx.write(&v_type_i, sizeof(v_type_i));
|
||||||
|
|
||||||
|
// Write element size
|
||||||
|
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
||||||
|
data_ctx.write(&v_size_el, sizeof(v_size_el));
|
||||||
|
|
||||||
|
// For each row, we get the element values of each cell
|
||||||
|
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||||
|
// Read each range of cells of v_size_el length each into tmp_buf and write out
|
||||||
|
for (const auto & range : cell_ranges) {
|
||||||
|
const size_t range_size = range.second - range.first;
|
||||||
|
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
|
||||||
|
tmp_buf.resize(range_size * v_size_el);
|
||||||
|
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
|
||||||
|
data_ctx.write(tmp_buf.data(), tmp_buf.size());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return data_ctx.get_size_written();
|
return data_ctx.get_size_written();
|
||||||
|
@ -16923,41 +16947,74 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For each layer, read the values for each cell (transposed)
|
if (!kv_self.v_trans) {
|
||||||
for (int il = 0; il < (int)n_layer; ++il) {
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
// Read type of value
|
// Read type of key
|
||||||
int32_t v_type_i_ref;
|
int32_t v_type_i_ref;
|
||||||
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
||||||
inp += sizeof(v_type_i_ref);
|
inp += sizeof(v_type_i_ref);
|
||||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||||
if (v_type_i != v_type_i_ref) {
|
if (v_type_i != v_type_i_ref) {
|
||||||
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||||
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Read element size of value
|
// Read row size of key
|
||||||
size_t v_size_el_ref;
|
size_t v_size_row_ref;
|
||||||
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
|
memcpy(&v_size_row_ref, inp, sizeof(v_size_row_ref));
|
||||||
inp += sizeof(v_size_el_ref);
|
inp += sizeof(v_size_row_ref);
|
||||||
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
const size_t v_size_row = ggml_row_size(kv_self.v_l[il]->type, n_embd_v_gqa);
|
||||||
if (v_size_el != v_size_el_ref) {
|
if (v_size_row != v_size_row_ref) {
|
||||||
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||||
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
|
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, v_size_row, v_size_row_ref, il);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cell_count) {
|
if (cell_count) {
|
||||||
// For each row in the transposed matrix, read the values for the whole cell range
|
// Read and set the keys for the whole cell range
|
||||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
ggml_backend_tensor_set(kv_self.v_l[il], inp, kv_head * v_size_row, cell_count * v_size_row);
|
||||||
const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
|
inp += cell_count * v_size_row;
|
||||||
ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
|
}
|
||||||
inp += cell_count * v_size_el;
|
}
|
||||||
|
} else {
|
||||||
|
// For each layer, read the values for each cell (transposed)
|
||||||
|
for (int il = 0; il < (int)n_layer; ++il) {
|
||||||
|
// Read type of value
|
||||||
|
int32_t v_type_i_ref;
|
||||||
|
memcpy(&v_type_i_ref, inp, sizeof(v_type_i_ref));
|
||||||
|
inp += sizeof(v_type_i_ref);
|
||||||
|
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||||
|
if (v_type_i != v_type_i_ref) {
|
||||||
|
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read element size of value
|
||||||
|
size_t v_size_el_ref;
|
||||||
|
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
|
||||||
|
inp += sizeof(v_size_el_ref);
|
||||||
|
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
|
||||||
|
if (v_size_el != v_size_el_ref) {
|
||||||
|
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, v_size_el_ref, il);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cell_count) {
|
||||||
|
// For each row in the transposed matrix, read the values for the whole cell range
|
||||||
|
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||||
|
const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
|
||||||
|
ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
|
||||||
|
inp += cell_count * v_size_el;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t nread = inp - src;
|
const size_t nread = inp - src;
|
||||||
|
|
||||||
return nread;
|
return nread;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
2
llama.h
2
llama.h
|
@ -526,7 +526,7 @@ extern "C" {
|
||||||
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
|
||||||
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
LLAMA_API int32_t llama_get_kv_cache_used_cells(const struct llama_context * ctx);
|
||||||
|
|
||||||
// Clear the KV cache
|
// Clear the KV cache - both cell info is erased and KV data is zeroed
|
||||||
LLAMA_API void llama_kv_cache_clear(
|
LLAMA_API void llama_kv_cache_clear(
|
||||||
struct llama_context * ctx);
|
struct llama_context * ctx);
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue