llama : save and restore kv cache for single seq id

This commit is contained in:
Jan Boon 2024-03-27 16:56:35 +08:00
parent e82f9e2b83
commit 662aaea8c9
No known key found for this signature in database
GPG key ID: 9873C4D40BB479BC
3 changed files with 489 additions and 1 deletions

View file

@ -61,7 +61,10 @@ enum server_task_type {
SERVER_TASK_TYPE_COMPLETION,
SERVER_TASK_TYPE_CANCEL,
SERVER_TASK_TYPE_NEXT_RESPONSE,
SERVER_TASK_TYPE_METRICS
SERVER_TASK_TYPE_METRICS,
SERVER_TASK_TYPE_SLOT_SAVE,
SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE,
};
struct server_task {
@ -1612,6 +1615,142 @@ struct server_context {
}
queue_results.send(res);
} break;
case SERVER_TASK_TYPE_SLOT_SAVE:
{
int id_slot = task.data["id_slot"];
server_slot * slot = get_slot(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
break;
}
const size_t token_count = slot->cache_tokens.size();
const int64_t t_start = ggml_time_us();
std::string filename = task.data["filename"];
size_t state_size = llama_get_seq_size(ctx, slot->id + 1);
std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
size_t nwrite = llama_copy_seq_data(ctx, state_data.data(), slot->id + 1);
GGML_ASSERT(nwrite <= state_size);
// write the cached token count of the slot->cache_tokens.size()
memcpy(state_data.data() + nwrite, &token_count, sizeof(size_t));
nwrite += sizeof(size_t);
// write the cached tokens (loop)
for (size_t i = 0; i < token_count; i++) {
const llama_token token = slot->cache_tokens[i];
memcpy(state_data.data() + nwrite, &token, sizeof(llama_token));
nwrite += sizeof(llama_token);
}
GGML_ASSERT(nwrite <= state_data.size());
std::ofstream outfile(filename, std::ios::binary);
outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite);
outfile.close();
const int64_t t_end = ggml_time_us();
const double t_save_ms = (t_end - t_start) / 1000.0;
server_task_result result;
result.id = task.id;
result.stop = true;
result.error = false;
result.data = json {
{ "id_slot", id_slot },
{ "filename", filename },
{ "n_saved", token_count }, // tokens saved
{ "n_written", nwrite }, // bytes written
{ "timings", {
{ "save_ms", t_save_ms }
} }
};
queue_results.send(result);
} break;
case SERVER_TASK_TYPE_SLOT_RESTORE:
{
int id_slot = task.data["id_slot"];
server_slot * slot = get_slot(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
break;
}
const int64_t t_start = ggml_time_us();
std::string filename = task.data["filename"]; // TODO: restrict to files in path specified in server params?
std::ifstream infile(filename, std::ios::binary);
if (!infile.is_open()) {
send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST);
break;
}
std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
infile.close();
size_t nread = llama_set_seq_data(ctx, state_data.data(), slot->id + 1);
GGML_ASSERT(nread <= state_data.size());
// restore cached token values
size_t token_count = 0;
if (nread + sizeof(size_t) <= state_data.size()) {
token_count = *reinterpret_cast<size_t *>(state_data.data() + nread);
nread += sizeof(size_t);
}
slot->cache_tokens.resize(token_count);
GGML_ASSERT(nread + (token_count * sizeof(llama_token)) <= state_data.size());
// tokens are of type llama_token (an integer)
for (size_t i = 0; i < token_count; i++) {
if (nread + sizeof(llama_token) <= state_data.size()) {
slot->cache_tokens[i] = *reinterpret_cast<llama_token *>(state_data.data() + nread);
nread += sizeof(llama_token);
}
}
GGML_ASSERT(nread <= state_data.size());
const int64_t t_end = ggml_time_us();
const double t_restore_ms = (t_end - t_start) / 1000.0;
server_task_result result;
result.id = task.id;
result.stop = true;
result.error = false;
result.data = json {
{ "id_slot", id_slot },
{ "filename", filename },
{ "n_restored", token_count }, // tokens restored
{ "n_read", nread }, // bytes read
{ "timings", {
{ "restore_ms", t_restore_ms }
} }
};
queue_results.send(result);
} break;
case SERVER_TASK_TYPE_SLOT_ERASE:
{
int id_slot = task.data["id_slot"];
server_slot * slot = get_slot(id_slot);
if (slot == nullptr) {
send_error(task, "Invalid slot ID", ERROR_TYPE_INVALID_REQUEST);
break;
}
// Erase token cache
const size_t n_erased = slot->cache_tokens.size();
llama_kv_cache_seq_rm(ctx, slot->id + 1, -1, -1);
slot->cache_tokens.clear();
server_task_result result;
result.id = task.id;
result.stop = true;
result.error = false;
result.data = json {
{ "id_slot", id_slot },
{ "n_erased", n_erased }
};
queue_results.send(result);
} break;
}
}
@ -3157,6 +3296,85 @@ int main(int argc, char ** argv) {
res.status = 200; // HTTP OK
};
const auto handle_slot_save = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json request_data = json::parse(req.body);
int id_slot = request_data["id_slot"];
std::string filename = request_data["filename"];
server_task task;
task.type = SERVER_TASK_TYPE_SLOT_SAVE;
task.data = {
{ "id_slot", id_slot },
{ "filename", filename },
};
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
res_error(res, result.data);
} else {
res.set_content(result.data.dump(), "application/json");
}
};
const auto handle_slot_restore = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json request_data = json::parse(req.body);
int id_slot = request_data["id_slot"];
std::string filename = request_data["filename"];
server_task task;
task.type = SERVER_TASK_TYPE_SLOT_RESTORE;
task.data = {
{ "id_slot", id_slot },
{ "filename", filename },
};
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
res_error(res, result.data);
} else {
res.set_content(result.data.dump(), "application/json");
}
};
const auto handle_slot_erase = [&ctx_server, &res_error](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json request_data = json::parse(req.body);
int id_slot = request_data["id_slot"];
server_task task;
task.type = SERVER_TASK_TYPE_SLOT_ERASE;
task.data = {
{ "id_slot", id_slot },
};
const int id_task = ctx_server.queue_tasks.post(task);
ctx_server.queue_results.add_waiting_task_id(id_task);
server_task_result result = ctx_server.queue_results.recv(id_task);
ctx_server.queue_results.remove_waiting_task_id(id_task);
if (result.error) {
res_error(res, result.data);
} else {
res.set_content(result.data, "application/json");
}
};
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
json data = {
@ -3519,6 +3737,9 @@ int main(int argc, char ** argv) {
svr->Post("/v1/embeddings", handle_embeddings);
svr->Post("/tokenize", handle_tokenize);
svr->Post("/detokenize", handle_detokenize);
svr->Post("/slot/save", handle_slot_save);
svr->Post("/slot/restore", handle_slot_restore);
svr->Post("/slot/erase", handle_slot_erase);
//
// Start the server

253
llama.cpp
View file

@ -15059,6 +15059,259 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
return true;
}
size_t llama_get_seq_size(struct llama_context* ctx, llama_seq_id seq_id) {
// save the size of size_t as a uint32_t for safety check
const size_t size_t_size_size = sizeof(uint32_t);
// other values
const size_t s_cell_count_size = sizeof(uint32_t);
const size_t s_layer_count_size = sizeof(uint32_t);
const size_t n_embd_v_gqa_size = sizeof(uint32_t);
size_t s_cell_count = 0;
size_t s_cell_data_size = 0;
const auto& kv_self = ctx->kv_self;
const auto& hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
for (uint32_t i = 0; i < kv_self.size; ++i) {
const auto& cell = kv_self.cells[i];
if (cell.seq_id.count(seq_id) > 0) {
++s_cell_count;
s_cell_data_size += sizeof(llama_pos);
}
}
for (int il = 0; il < (int)n_layer; ++il) {
// k_size_row and v_size_el values of layer
s_cell_data_size += sizeof(size_t) * 2;
// keys
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
s_cell_data_size += k_size_row * s_cell_count;
// values (transposed)
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
s_cell_data_size += v_size_el * s_cell_count * n_embd_v_gqa;
}
const size_t s_total = (
size_t_size_size +
s_cell_count_size +
s_layer_count_size +
n_embd_v_gqa_size +
s_cell_data_size
);
return s_total;
}
size_t llama_copy_seq_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
llama_data_buffer_context data_ctx(dst);
// Save the size of size_t as a uint32_t for safety check
const uint32_t size_t_size = sizeof(size_t);
data_ctx.write(&size_t_size, sizeof(size_t_size));
const auto& kv_self = ctx->kv_self;
std::vector<std::pair<uint32_t, uint32_t>> cell_ranges; // ranges, from inclusive, to exclusive
uint32_t cell_count = 0;
// Count the number of cells with the specified seq_id
// Find all the ranges of cells with this seq id
{
uint32_t cell_range_begin = kv_self.size;
for (uint32_t i = 0; i < kv_self.size; ++i) {
const auto& cell = kv_self.cells[i];
if (cell.has_seq_id(seq_id)) {
++cell_count;
if (cell_range_begin == kv_self.size) {
cell_range_begin = i;
}
}
else {
if (cell_range_begin != kv_self.size) {
cell_ranges.push_back({ cell_range_begin, i });
cell_range_begin = kv_self.size;
}
}
}
if (cell_range_begin != kv_self.size) {
cell_ranges.push_back({ cell_range_begin, kv_self.size });
}
// DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count
uint32_t cell_count_check = 0;
for (const auto& range : cell_ranges) {
cell_count_check += range.second - range.first;
}
GGML_ASSERT(cell_count == cell_count_check);
}
// Write the cell count
data_ctx.write(&cell_count, sizeof(cell_count));
const auto & hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
// Write the layer count
data_ctx.write(&n_layer, sizeof(n_layer));
// Write n_embd_v_gqa
data_ctx.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa));
// Iterate the ranges and write all the pos (this is the token position in the prompt)
for (const auto & range : cell_ranges) {
for (uint32_t i = range.first; i < range.second; ++i) {
const auto & cell = kv_self.cells[i];
data_ctx.write(&cell.pos, sizeof(cell.pos));
}
}
// Iterate and write all the keys first, each row is a cell
// Get whole range at a time
std::vector<uint8_t> tmp_buf;
for (int il = 0; il < (int)n_layer; ++il) {
// Write row size of key
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
data_ctx.write(&k_size_row, sizeof(k_size_row));
// Read each range of cells of k_size length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
tmp_buf.resize(range_size * k_size_row);
ggml_backend_tensor_get(kv_self.k_l[il], tmp_buf.data(), range.first * k_size_row, range_size * k_size_row);
data_ctx.write(&tmp_buf[0], tmp_buf.size());
}
}
// For the values, they are transposed, so we also need the element size and get the element ranges from each row
const uint32_t kv_size = kv_self.size;
for (int il = 0; il < (int)n_layer; ++il) {
// Write element size
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
data_ctx.write(&v_size_el, sizeof(v_size_el));
// For each row, we get the element values of each cell
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
// Read each range of cells of v_size_el length each into tmp_buf and write out
for (const auto & range : cell_ranges) {
const size_t range_size = range.second - range.first;
const size_t src_offset = (range.first + j * kv_size) * v_size_el;
tmp_buf.resize(range_size * v_size_el);
ggml_backend_tensor_get(kv_self.v_l[il], tmp_buf.data(), src_offset, tmp_buf.size());
data_ctx.write(&tmp_buf[0], tmp_buf.size());
}
}
}
return data_ctx.get_size_written();
}
size_t llama_set_seq_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
auto & kv_self = ctx->kv_self;
// Wipe the slot
llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
const uint8_t * inp = src;
// Read size of size_t
uint32_t size_t_size;
memcpy(&size_t_size, inp, sizeof(size_t_size));
inp += sizeof(size_t_size);
GGML_ASSERT(size_t_size == sizeof(size_t));
// Read the cell count
uint32_t cell_count;
memcpy(&cell_count, inp, sizeof(cell_count));
inp += sizeof(cell_count);
// Read the layer count
uint32_t n_layer_ref;
memcpy(&n_layer_ref, inp, sizeof(n_layer_ref));
inp += sizeof(n_layer_ref);
// Read n_embd_v_gqa
uint32_t n_embd_v_gqa_ref;
memcpy(&n_embd_v_gqa_ref, inp, sizeof(n_embd_v_gqa_ref));
inp += sizeof(n_embd_v_gqa_ref);
// Allocate the new cells for the slot
llama_batch batch = llama_batch_init(cell_count, 0, 1);
batch.n_tokens = cell_count;
for (uint32_t i = 0; i < cell_count; ++i) {
llama_pos pos;
memcpy(&pos, inp, sizeof(pos));
inp += sizeof(pos);
batch.pos[i] = pos;
batch.n_seq_id[i] = 1;
batch.seq_id[i][0] = dest_seq_id;
}
llama_kv_cache_find_slot(kv_self, batch);
// DEBUG CHECK: kv_self.head should be our first cell, kv_self.head + cell_count - 1 should be our last cell (verify seq_id and pos values)
// Assume that this is one contiguous block of cells
GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
const auto& hparams = ctx->model.hparams;
const uint32_t n_layer = hparams.n_layer;
const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa() + hparams.n_embd_k_s();
const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa() + hparams.n_embd_v_s();
const uint32_t kv_size = kv_self.size;
const uint32_t kv_head = kv_self.head;
GGML_ASSERT(n_layer == n_layer_ref);
GGML_ASSERT(n_embd_v_gqa == n_embd_v_gqa_ref);
// For each layer, read the keys for each cell, one row is one cell, read as one contiguous blo
for (int il = 0; il < (int)n_layer; ++il) {
// Read row size of key
size_t k_size_row_ref;
memcpy(&k_size_row_ref, inp, sizeof(k_size_row_ref));
inp += sizeof(k_size_row_ref);
const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
GGML_ASSERT(k_size_row == k_size_row_ref);
// Read and set the keys for the whole cell range
ggml_backend_tensor_set(kv_self.k_l[il], inp, kv_head * k_size_row, cell_count * k_size_row);
inp += cell_count * k_size_row;
}
// For each layer, read the values for each cell (transposed)
for (int il = 0; il < (int)n_layer; ++il) {
// Read element size of value
size_t v_size_el_ref;
memcpy(&v_size_el_ref, inp, sizeof(v_size_el_ref));
inp += sizeof(v_size_el_ref);
const size_t v_size_el = ggml_type_size(kv_self.v_l[il]->type);
GGML_ASSERT(v_size_el == v_size_el_ref);
// For each row in the transposed matrix, read the values for the whole cell range
for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
const size_t dst_offset = (kv_head + j * kv_size) * v_size_el;
ggml_backend_tensor_set(kv_self.v_l[il], inp, dst_offset, cell_count * v_size_el);
inp += cell_count * v_size_el;
}
}
// Cleanup
llama_batch_free(batch);
const size_t nread = inp - src;
return nread;
}
void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
ctx->cparams.n_threads = n_threads;
ctx->cparams.n_threads_batch = n_threads_batch;

14
llama.h
View file

@ -623,6 +623,20 @@ extern "C" {
const llama_token * tokens,
size_t n_token_count);
LLAMA_API size_t llama_get_seq_size(
struct llama_context * ctx,
llama_seq_id seq_id);
LLAMA_API size_t llama_copy_seq_data(
struct llama_context * ctx,
uint8_t * dst,
llama_seq_id seq_id);
LLAMA_API size_t llama_set_seq_data(
struct llama_context * ctx,
const uint8_t * src,
llama_seq_id dest_seq_id);
//
// Decoding
//