llama : session saving and reloading for hybrid models
This commit is contained in:
parent
bc320ef66d
commit
fcb889cf7f
2 changed files with 391 additions and 134 deletions
|
@ -38,10 +38,10 @@
|
||||||
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
|
||||||
|
|
||||||
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
#define LLAMA_SESSION_MAGIC LLAMA_FILE_MAGIC_GGSN
|
||||||
#define LLAMA_SESSION_VERSION 8
|
#define LLAMA_SESSION_VERSION 9
|
||||||
|
|
||||||
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
#define LLAMA_STATE_SEQ_MAGIC LLAMA_FILE_MAGIC_GGSQ
|
||||||
#define LLAMA_STATE_SEQ_VERSION 2
|
#define LLAMA_STATE_SEQ_VERSION 3
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
extern "C" {
|
extern "C" {
|
||||||
|
|
521
src/llama.cpp
521
src/llama.cpp
|
@ -19839,8 +19839,28 @@ struct llama_data_write {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void write_rs_cache_meta(const llama_rs_cache & rs_self, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges, llama_seq_id seq_id = -1) {
|
||||||
|
|
||||||
|
for (const auto & range : cell_ranges) {
|
||||||
|
for (uint32_t i = range.first; i < range.second; ++i) {
|
||||||
|
const auto & cell = rs_self.cells[i];
|
||||||
|
const llama_pos pos = cell.pos;
|
||||||
|
const uint32_t n_seq_id = seq_id == -1 ? cell.seq_nodes.size() : 0;
|
||||||
|
|
||||||
|
write(&pos, sizeof(pos));
|
||||||
|
write(&n_seq_id, sizeof(n_seq_id));
|
||||||
|
|
||||||
|
if (n_seq_id) {
|
||||||
|
for (auto seq_node : cell.seq_nodes) {
|
||||||
|
write(&seq_node.seq_id, sizeof(seq_node.seq_id));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
|
void write_kv_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
|
||||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
const struct llama_kv_cache & kv_self = ctx->cache.kv;
|
||||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
|
const uint32_t v_trans = kv_self.v_trans ? 1 : 0;
|
||||||
|
@ -19849,12 +19869,10 @@ struct llama_data_write {
|
||||||
write(&v_trans, sizeof(v_trans));
|
write(&v_trans, sizeof(v_trans));
|
||||||
write(&n_layer, sizeof(n_layer));
|
write(&n_layer, sizeof(n_layer));
|
||||||
|
|
||||||
std::vector<uint8_t> tmp_buf;
|
|
||||||
|
|
||||||
// Iterate and write all the keys first, each row is a cell
|
// Iterate and write all the keys first, each row is a cell
|
||||||
// Get whole range at a time
|
// Get whole range at a time
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||||
|
|
||||||
// Write key type
|
// Write key type
|
||||||
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
|
||||||
|
@ -19874,7 +19892,7 @@ struct llama_data_write {
|
||||||
|
|
||||||
if (!kv_self.v_trans) {
|
if (!kv_self.v_trans) {
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||||
|
|
||||||
// Write value type
|
// Write value type
|
||||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||||
|
@ -19895,7 +19913,7 @@ struct llama_data_write {
|
||||||
// When v is transposed, we also need the element size and get the element ranges from each row
|
// When v is transposed, we also need the element size and get the element ranges from each row
|
||||||
const uint32_t kv_size = kv_self.size;
|
const uint32_t kv_size = kv_self.size;
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||||
|
|
||||||
// Write value type
|
// Write value type
|
||||||
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
|
||||||
|
@ -19922,43 +19940,151 @@ struct llama_data_write {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void write_kv_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
void write_rs_cache_data(const struct llama_context * ctx, const std::vector<std::pair<uint32_t, uint32_t>> & cell_ranges) {
|
||||||
const struct llama_kv_cache & kv_self = ctx->kv_self;
|
const struct llama_rs_cache & rs_self = ctx->cache.rs;
|
||||||
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
|
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||||
uint32_t cell_count = 0;
|
|
||||||
|
|
||||||
// Count the number of cells with the specified seq_id
|
const uint32_t n_layer = hparams.n_layer;
|
||||||
// Find all the ranges of cells with this seq id (or all, when -1)
|
|
||||||
uint32_t cell_range_begin = kv_self.size;
|
write(&n_layer, sizeof(n_layer));
|
||||||
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
|
||||||
const auto & cell = kv_self.cells[i];
|
// Iterate and write all recurrent states, each row is a cell
|
||||||
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
// Get whole range at a time
|
||||||
++cell_count;
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
if (cell_range_begin == kv_self.size) {
|
const uint32_t n_embd_r = hparams.n_embd_r(il);
|
||||||
cell_range_begin = i;
|
|
||||||
|
// Write type
|
||||||
|
const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type;
|
||||||
|
write(&r_type_i, sizeof(r_type_i));
|
||||||
|
|
||||||
|
// Write row size
|
||||||
|
const uint64_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r);
|
||||||
|
write(&r_size_row, sizeof(r_size_row));
|
||||||
|
|
||||||
|
// Read each range of cells of r_size length each and write out
|
||||||
|
for (const auto & range : cell_ranges) {
|
||||||
|
const size_t range_size = range.second - range.first;
|
||||||
|
const size_t buf_size = range_size * r_size_row;
|
||||||
|
write_tensor_data(rs_self.r_l[il], range.first * r_size_row, buf_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_s = hparams.n_embd_s(il);
|
||||||
|
|
||||||
|
// Write type
|
||||||
|
const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type;
|
||||||
|
write(&s_type_i, sizeof(s_type_i));
|
||||||
|
|
||||||
|
// Write row size
|
||||||
|
const uint64_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s);
|
||||||
|
write(&s_size_row, sizeof(s_size_row));
|
||||||
|
|
||||||
|
// Read each range of cells of s_size length each and write out
|
||||||
|
for (const auto & range : cell_ranges) {
|
||||||
|
const size_t range_size = range.second - range.first;
|
||||||
|
const size_t buf_size = range_size * s_size_row;
|
||||||
|
write_tensor_data(rs_self.s_l[il], range.first * s_size_row, buf_size);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void write_cache(const struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
||||||
|
const struct llama_kv_cache & kv_self = ctx->cache.kv;
|
||||||
|
const struct llama_rs_cache & rs_self = ctx->cache.rs;
|
||||||
|
std::vector<std::pair<uint32_t, uint32_t>> kv_cell_ranges; // ranges, from inclusive, to exclusive
|
||||||
|
std::vector<std::pair<uint32_t, uint32_t>> rs_cell_ranges; // ranges, from inclusive, to exclusive
|
||||||
|
uint32_t kv_cell_count = 0;
|
||||||
|
uint32_t rs_cell_count = 0;
|
||||||
|
// Transformer KV cache
|
||||||
|
{
|
||||||
|
// Count the number of cells with the specified seq_id
|
||||||
|
// Find all the ranges of cells with this seq id (or all, when -1)
|
||||||
|
uint32_t cell_range_begin = kv_self.size;
|
||||||
|
for (uint32_t i = 0; i < kv_self.size; ++i) {
|
||||||
|
const auto & cell = kv_self.cells[i];
|
||||||
|
if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) {
|
||||||
|
++kv_cell_count;
|
||||||
|
if (cell_range_begin == kv_self.size) {
|
||||||
|
cell_range_begin = i;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (cell_range_begin != kv_self.size) {
|
||||||
|
kv_cell_ranges.emplace_back(cell_range_begin, i);
|
||||||
|
cell_range_begin = kv_self.size;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
if (cell_range_begin != kv_self.size) {
|
if (cell_range_begin != kv_self.size) {
|
||||||
cell_ranges.emplace_back(cell_range_begin, i);
|
kv_cell_ranges.emplace_back(cell_range_begin, kv_self.size);
|
||||||
cell_range_begin = kv_self.size;
|
}
|
||||||
|
|
||||||
|
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||||
|
uint32_t cell_count_check = 0;
|
||||||
|
for (const auto & range : kv_cell_ranges) {
|
||||||
|
cell_count_check += range.second - range.first;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(kv_cell_count == cell_count_check);
|
||||||
|
}
|
||||||
|
// Recurrent state cache
|
||||||
|
if (seq_id == -1) {
|
||||||
|
// Find all the ranges of cells
|
||||||
|
uint32_t cell_range_begin = rs_self.size;
|
||||||
|
for (uint32_t i = 0; i < rs_self.size; ++i) {
|
||||||
|
const auto & cell = rs_self.cells[i];
|
||||||
|
if (!cell.is_empty()) {
|
||||||
|
++rs_cell_count;
|
||||||
|
if (cell_range_begin == rs_self.size) {
|
||||||
|
cell_range_begin = i;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (cell_range_begin != rs_self.size) {
|
||||||
|
rs_cell_ranges.emplace_back(cell_range_begin, i);
|
||||||
|
cell_range_begin = rs_self.size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (cell_range_begin != rs_self.size) {
|
||||||
|
rs_cell_ranges.emplace_back(cell_range_begin, rs_self.size);
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// Find the cell ranges of the specified seq_id
|
||||||
|
if ((size_t) seq_id < rs_self.seq_tails.size()) {
|
||||||
|
int32_t tail_cell_id = rs_self.seq_tails[seq_id].tail;
|
||||||
|
if (tail_cell_id >= 0) {
|
||||||
|
++rs_cell_count;
|
||||||
|
rs_cell_ranges.emplace_back(tail_cell_id, tail_cell_id + 1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if (cell_range_begin != kv_self.size) {
|
|
||||||
cell_ranges.emplace_back(cell_range_begin, kv_self.size);
|
{
|
||||||
|
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
||||||
|
uint32_t cell_count_check = 0;
|
||||||
|
for (const auto & range : rs_cell_ranges) {
|
||||||
|
cell_count_check += range.second - range.first;
|
||||||
|
}
|
||||||
|
GGML_ASSERT(rs_cell_count == cell_count_check);
|
||||||
}
|
}
|
||||||
|
|
||||||
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
|
write(&kv_cell_count, sizeof(kv_cell_count));
|
||||||
uint32_t cell_count_check = 0;
|
write(&rs_cell_count, sizeof(rs_cell_count));
|
||||||
for (const auto & range : cell_ranges) {
|
|
||||||
cell_count_check += range.second - range.first;
|
if (seq_id == -1) {
|
||||||
|
// write metadata for both when the whole cache needs to be saved
|
||||||
|
write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id);
|
||||||
|
write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id);
|
||||||
|
} else if (kv_cell_count > 0) {
|
||||||
|
write_kv_cache_meta(kv_self, kv_cell_ranges, seq_id);
|
||||||
|
} else {
|
||||||
|
write_rs_cache_meta(rs_self, rs_cell_ranges, seq_id);
|
||||||
|
}
|
||||||
|
if (kv_cell_count > 0) {
|
||||||
|
write_kv_cache_data(ctx, kv_cell_ranges);
|
||||||
|
}
|
||||||
|
if (rs_cell_count > 0) {
|
||||||
|
write_rs_cache_data(ctx, rs_cell_ranges);
|
||||||
}
|
}
|
||||||
GGML_ASSERT(cell_count == cell_count_check);
|
|
||||||
|
|
||||||
write(&cell_count, sizeof(cell_count));
|
|
||||||
|
|
||||||
write_kv_cache_meta(kv_self, cell_ranges, seq_id);
|
|
||||||
write_kv_cache_data(ctx, cell_ranges);
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -20050,108 +20176,98 @@ struct llama_data_read {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
|
bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count) {
|
||||||
struct llama_kv_cache & kv_self = ctx->kv_self;
|
if (cell_count == 0) { return true; }
|
||||||
|
struct llama_past & cache = ctx->cache;
|
||||||
|
struct llama_kv_cache & kv_self = cache.kv;
|
||||||
|
|
||||||
if (dest_seq_id != -1) {
|
// whole KV cache restore
|
||||||
// single sequence
|
|
||||||
|
|
||||||
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
|
if (cell_count > kv_self.size) {
|
||||||
|
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||||
batch.n_tokens = cell_count;
|
llama_kv_cell & cell = kv_self.cells[i];
|
||||||
batch.n_seq_tokens = cell_count;
|
|
||||||
batch.n_seqs = 1;
|
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
llama_pos pos;
|
||||||
llama_pos pos;
|
uint32_t n_seq_id;
|
||||||
uint32_t n_seq_id;
|
|
||||||
|
|
||||||
read_to(&pos, sizeof(pos));
|
read_to(&pos, sizeof(pos));
|
||||||
read_to(&n_seq_id, sizeof(n_seq_id));
|
read_to(&n_seq_id, sizeof(n_seq_id));
|
||||||
|
|
||||||
if (n_seq_id != 0) {
|
cell.pos = pos;
|
||||||
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
|
||||||
|
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||||
|
llama_seq_id seq_id;
|
||||||
|
read_to(&seq_id, sizeof(seq_id));
|
||||||
|
|
||||||
|
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
||||||
|
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
batch.pos[i] = pos;
|
cell.seq_id.insert(seq_id);
|
||||||
}
|
|
||||||
batch.n_seq_id[0] = 1;
|
|
||||||
batch.seq_id[0] = &dest_seq_id;
|
|
||||||
if (!llama_kv_cache_find_slot(kv_self, batch)) {
|
|
||||||
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
|
||||||
return false;
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
kv_self.head = 0;
|
||||||
// Assume that this is one contiguous block of cells
|
kv_self.used = cell_count;
|
||||||
GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
|
|
||||||
GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
|
|
||||||
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
|
||||||
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
|
|
||||||
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
|
|
||||||
} else {
|
|
||||||
// whole KV cache restore
|
|
||||||
|
|
||||||
if (cell_count > kv_self.size) {
|
return true;
|
||||||
LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__);
|
}
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_kv_cache_clear(kv_self);
|
bool read_rs_cache_meta(struct llama_context * ctx, uint32_t cell_count) {
|
||||||
|
if (cell_count == 0) { return true; }
|
||||||
|
struct llama_past & cache = ctx->cache;
|
||||||
|
struct llama_rs_cache & rs_self = cache.rs;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
// whole RS cache restore
|
||||||
llama_kv_cell & cell = kv_self.cells[i];
|
|
||||||
|
|
||||||
llama_pos pos;
|
if (cell_count > rs_self.size) {
|
||||||
uint32_t n_seq_id;
|
LLAMA_LOG_ERROR("%s: not enough cells in rs cache\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
read_to(&pos, sizeof(pos));
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||||
read_to(&n_seq_id, sizeof(n_seq_id));
|
llama_rs_cell & cell = rs_self.cells[i];
|
||||||
|
|
||||||
cell.pos = pos;
|
llama_pos pos;
|
||||||
|
uint32_t n_seq_id;
|
||||||
|
|
||||||
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
read_to(&pos, sizeof(pos));
|
||||||
llama_seq_id seq_id;
|
read_to(&n_seq_id, sizeof(n_seq_id));
|
||||||
read_to(&seq_id, sizeof(seq_id));
|
|
||||||
|
|
||||||
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
cell.pos = pos;
|
||||||
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
cell.src = i;
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
cell.seq_id.insert(seq_id);
|
for (uint32_t j = 0; j < n_seq_id; ++j) {
|
||||||
|
llama_seq_id seq_id;
|
||||||
|
read_to(&seq_id, sizeof(seq_id));
|
||||||
|
|
||||||
if (kv_self.recurrent) {
|
if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
|
||||||
int32_t & tail = kv_self.cells[seq_id].tail;
|
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
|
||||||
if (tail != -1) {
|
return false;
|
||||||
LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
tail = i;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
kv_self.head = 0;
|
cell.insert_node(seq_id);
|
||||||
kv_self.used = cell_count;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (kv_self.recurrent) {
|
|
||||||
for (uint32_t i = 0; i < cell_count; ++i) {
|
|
||||||
uint32_t cell_id = kv_self.head + i;
|
|
||||||
// make sure the recurrent states will keep their restored state
|
|
||||||
kv_self.cells[cell_id].src = cell_id;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rs_self.head = 0;
|
||||||
|
rs_self.used = cell_count;
|
||||||
|
|
||||||
|
rs_self.rebuild(/* debug */ false);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
|
bool read_kv_cache_data(struct llama_context * ctx, uint32_t cell_count) {
|
||||||
|
if (cell_count == 0) { return true; }
|
||||||
const struct llama_hparams & hparams = ctx->model.hparams;
|
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||||
struct llama_kv_cache & kv_self = ctx->kv_self;
|
struct llama_kv_cache & kv_self = ctx->cache.kv;
|
||||||
uint32_t v_trans;
|
uint32_t v_trans;
|
||||||
uint32_t n_layer;
|
uint32_t n_layer;
|
||||||
read_to(&v_trans, sizeof(v_trans));
|
read_to(&v_trans, sizeof(v_trans));
|
||||||
|
@ -20172,7 +20288,7 @@ struct llama_data_read {
|
||||||
|
|
||||||
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s();
|
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il);
|
||||||
|
|
||||||
// Read type of key
|
// Read type of key
|
||||||
int32_t k_type_i_ref;
|
int32_t k_type_i_ref;
|
||||||
|
@ -20192,15 +20308,13 @@ struct llama_data_read {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cell_count) {
|
// Read and set the keys for the whole cell range
|
||||||
// Read and set the keys for the whole cell range
|
ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
|
||||||
ggml_backend_tensor_set(kv_self.k_l[il], read(cell_count * k_size_row), kv_self.head * k_size_row, cell_count * k_size_row);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!kv_self.v_trans) {
|
if (!kv_self.v_trans) {
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||||
|
|
||||||
// Read type of value
|
// Read type of value
|
||||||
int32_t v_type_i_ref;
|
int32_t v_type_i_ref;
|
||||||
|
@ -20220,15 +20334,13 @@ struct llama_data_read {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cell_count) {
|
// Read and set the values for the whole cell range
|
||||||
// Read and set the values for the whole cell range
|
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
|
||||||
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_row), kv_self.head * v_size_row, cell_count * v_size_row);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// For each layer, read the values for each cell (transposed)
|
// For each layer, read the values for each cell (transposed)
|
||||||
for (uint32_t il = 0; il < n_layer; ++il) {
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s();
|
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il);
|
||||||
|
|
||||||
// Read type of value
|
// Read type of value
|
||||||
int32_t v_type_i_ref;
|
int32_t v_type_i_ref;
|
||||||
|
@ -20256,29 +20368,174 @@ struct llama_data_read {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (cell_count) {
|
// For each row in the transposed matrix, read the values for the whole cell range
|
||||||
// For each row in the transposed matrix, read the values for the whole cell range
|
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
||||||
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
|
const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
|
||||||
const size_t dst_offset = (kv_self.head + j * kv_self.size) * v_size_el;
|
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
||||||
ggml_backend_tensor_set(kv_self.v_l[il], read(cell_count * v_size_el), dst_offset, cell_count * v_size_el);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
bool read_rs_cache_data(struct llama_context * ctx, uint32_t cell_count) {
|
||||||
uint32_t cell_count;
|
if (cell_count == 0) { return true; }
|
||||||
read_to(&cell_count, sizeof(cell_count));
|
const struct llama_hparams & hparams = ctx->model.hparams;
|
||||||
|
struct llama_rs_cache & rs_self = ctx->cache.rs;
|
||||||
|
uint32_t n_layer;
|
||||||
|
read_to(&n_layer, sizeof(n_layer));
|
||||||
|
|
||||||
bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
|
if (n_layer != hparams.n_layer) {
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
if (cell_count > rs_self.size) {
|
||||||
|
LLAMA_LOG_ERROR("%s: not enough cells in rs cache to restore state (%u > %u)\n", __func__, cell_count, rs_self.size);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// For each layer, one row is one cell, read as one contiguous block
|
||||||
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_r = hparams.n_embd_r(il);
|
||||||
|
|
||||||
|
// Read type of key
|
||||||
|
int32_t r_type_i_ref;
|
||||||
|
read_to(&r_type_i_ref, sizeof(r_type_i_ref));
|
||||||
|
const int32_t r_type_i = (int32_t)rs_self.r_l[il]->type;
|
||||||
|
if (r_type_i != r_type_i_ref) {
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, r_type_i, r_type_i_ref, il);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read row size of key
|
||||||
|
uint64_t r_size_row_ref;
|
||||||
|
read_to(&r_size_row_ref, sizeof(r_size_row_ref));
|
||||||
|
const size_t r_size_row = ggml_row_size(rs_self.r_l[il]->type, n_embd_r);
|
||||||
|
if (r_size_row != r_size_row_ref) {
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, r_size_row, (size_t) r_size_row_ref, il);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and set the keys for the whole cell range
|
||||||
|
ggml_backend_tensor_set(rs_self.r_l[il], read(cell_count * r_size_row), rs_self.head * r_size_row, cell_count * r_size_row);
|
||||||
|
}
|
||||||
|
for (uint32_t il = 0; il < n_layer; ++il) {
|
||||||
|
const uint32_t n_embd_s = hparams.n_embd_s(il);
|
||||||
|
|
||||||
|
// Read type of key
|
||||||
|
int32_t s_type_i_ref;
|
||||||
|
read_to(&s_type_i_ref, sizeof(s_type_i_ref));
|
||||||
|
const int32_t s_type_i = (int32_t)rs_self.s_l[il]->type;
|
||||||
|
if (s_type_i != s_type_i_ref) {
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, s_type_i, s_type_i_ref, il);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read row size of key
|
||||||
|
uint64_t s_size_row_ref;
|
||||||
|
read_to(&s_size_row_ref, sizeof(s_size_row_ref));
|
||||||
|
const size_t s_size_row = ggml_row_size(rs_self.s_l[il]->type, n_embd_s);
|
||||||
|
if (s_size_row != s_size_row_ref) {
|
||||||
|
LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, s_size_row, (size_t) s_size_row_ref, il);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read and set the keys for the whole cell range
|
||||||
|
ggml_backend_tensor_set(rs_self.s_l[il], read(cell_count * s_size_row), rs_self.head * s_size_row, cell_count * s_size_row);
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool read_cache_seq_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id seq_id = -1) {
|
||||||
|
|
||||||
|
if (seq_id < 0 || seq_id >= llama_n_seq_max(ctx)) {
|
||||||
|
LLAMA_LOG_ERROR("%s: seq_id out of range [0, %d): %d\n", __func__, llama_n_seq_max(ctx), seq_id);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// single sequence
|
||||||
|
|
||||||
|
llama_past & cache = ctx->cache;
|
||||||
|
llama_ubatch batch = ctx->sbatch.reserve_ubatch(cell_count, /* has_embd */ false);
|
||||||
|
batch.n_tokens = cell_count;
|
||||||
|
batch.n_seq_tokens = cell_count;
|
||||||
|
batch.n_seqs = 1;
|
||||||
|
|
||||||
|
for (uint32_t i = 0; i < cell_count; ++i) {
|
||||||
|
llama_pos pos;
|
||||||
|
uint32_t n_seq_id;
|
||||||
|
|
||||||
|
read_to(&pos, sizeof(pos));
|
||||||
|
read_to(&n_seq_id, sizeof(n_seq_id));
|
||||||
|
|
||||||
|
if (n_seq_id != 0) {
|
||||||
|
LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.pos[i] = pos;
|
||||||
|
}
|
||||||
|
batch.n_seq_id[0] = 1;
|
||||||
|
batch.seq_id[0] = &seq_id;
|
||||||
|
if (!llama_past_find_slot(cache, batch)) {
|
||||||
|
LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (cache.kv.size > 0) {
|
||||||
|
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
|
||||||
|
// Assume that this is one contiguous block of cells
|
||||||
|
GGML_ASSERT(cache.kv.head + cell_count <= cache.kv.size);
|
||||||
|
GGML_ASSERT(cache.kv.cells[cache.kv.head].pos == batch.pos[0]);
|
||||||
|
GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
|
||||||
|
GGML_ASSERT(cache.kv.cells[cache.kv.head].has_seq_id(seq_id));
|
||||||
|
GGML_ASSERT(cache.kv.cells[cache.kv.head + cell_count - 1].has_seq_id(seq_id));
|
||||||
|
}
|
||||||
|
if (cache.rs.size > 0) {
|
||||||
|
GGML_ASSERT(cache.rs.head + cache.rs.n <= cache.rs.size);
|
||||||
|
GGML_ASSERT(cache.rs.n == 1);
|
||||||
|
GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].pos == batch.pos[cell_count - 1]);
|
||||||
|
GGML_ASSERT(cache.rs.cells[cache.rs.head].has_seq_id(seq_id));
|
||||||
|
GGML_ASSERT(cache.rs.cells[cache.rs.head + cache.rs.n - 1].has_seq_id(seq_id));
|
||||||
|
// Prevent cells from being cleared
|
||||||
|
for (uint32_t i = cache.rs.head; i < cache.rs.head + cache.rs.n; ++i) {
|
||||||
|
cache.rs.cells[i].src = i;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
void read_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
|
||||||
|
uint32_t kv_cell_count;
|
||||||
|
read_to(&kv_cell_count, sizeof(kv_cell_count));
|
||||||
|
uint32_t rs_cell_count;
|
||||||
|
read_to(&rs_cell_count, sizeof(rs_cell_count));
|
||||||
|
|
||||||
|
bool res = true;
|
||||||
|
|
||||||
|
if (seq_id == -1) {
|
||||||
|
llama_past_clear(ctx);
|
||||||
|
res = read_kv_cache_meta(ctx, kv_cell_count) && read_rs_cache_meta(ctx, rs_cell_count);
|
||||||
|
} else {
|
||||||
|
llama_past_seq_rm(ctx, seq_id, -1, -1);
|
||||||
|
// Only a single recurrent cell at most,
|
||||||
|
// because otherwise the cells can be shuffled when a slot is allocated
|
||||||
|
if (rs_cell_count > 1) {
|
||||||
|
LLAMA_LOG_ERROR("%s: too many recurrent state cells for single-sequence session\n", __func__);
|
||||||
|
res = false;
|
||||||
|
}
|
||||||
|
res = res && read_cache_seq_meta(ctx, std::max(kv_cell_count, rs_cell_count), seq_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
res = res && read_kv_cache_data(ctx, kv_cell_count) && read_rs_cache_data(ctx, rs_cell_count);
|
||||||
|
|
||||||
if (!res) {
|
if (!res) {
|
||||||
if (seq_id == -1) {
|
if (seq_id == -1) {
|
||||||
llama_kv_cache_clear(ctx);
|
llama_past_clear(ctx);
|
||||||
} else {
|
} else {
|
||||||
llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
|
llama_past_seq_rm(ctx, seq_id, -1, -1);
|
||||||
}
|
}
|
||||||
throw std::runtime_error("failed to restore kv cache");
|
throw std::runtime_error("failed to restore kv cache");
|
||||||
}
|
}
|
||||||
|
@ -20433,7 +20690,7 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
|
||||||
data_ctx.write_logits(ctx);
|
data_ctx.write_logits(ctx);
|
||||||
data_ctx.write_embeddings(ctx);
|
data_ctx.write_embeddings(ctx);
|
||||||
|
|
||||||
data_ctx.write_kv_cache(ctx);
|
data_ctx.write_cache(ctx);
|
||||||
|
|
||||||
return data_ctx.get_size_written();
|
return data_ctx.get_size_written();
|
||||||
}
|
}
|
||||||
|
@ -20473,7 +20730,7 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
|
||||||
data_ctx.read_logits(ctx);
|
data_ctx.read_logits(ctx);
|
||||||
data_ctx.read_embeddings(ctx);
|
data_ctx.read_embeddings(ctx);
|
||||||
|
|
||||||
data_ctx.read_kv_cache(ctx);
|
data_ctx.read_cache(ctx);
|
||||||
|
|
||||||
return data_ctx.get_size_read();
|
return data_ctx.get_size_read();
|
||||||
}
|
}
|
||||||
|
@ -20569,7 +20826,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
|
||||||
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
|
static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_write & data_ctx, llama_seq_id seq_id) {
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
data_ctx.write_kv_cache(ctx, seq_id);
|
data_ctx.write_cache(ctx, seq_id);
|
||||||
|
|
||||||
return data_ctx.get_size_written();
|
return data_ctx.get_size_written();
|
||||||
}
|
}
|
||||||
|
@ -20592,7 +20849,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
|
||||||
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
|
static size_t llama_state_seq_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx, llama_seq_id dest_seq_id) {
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
data_ctx.read_kv_cache(ctx, dest_seq_id);
|
data_ctx.read_cache(ctx, dest_seq_id);
|
||||||
|
|
||||||
return data_ctx.get_size_read();
|
return data_ctx.get_size_read();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue