move sequence state file functionality from server to llama to match session api and add version tags

Jan Boon 2024-03-31 03:26:25 +08:00
parent 129b6ffea6
commit 8af72118ec
3 changed files with 102 additions and 48 deletions

examples/server/server.cpp

@@ -1630,26 +1630,8 @@ struct server_context {
                     std::string filename = task.data["filename"];
                     std::string filepath = task.data["filepath"];

-                    size_t state_size = llama_state_seq_get_size(ctx, slot->id + 1);
-                    std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
-                    size_t nwrite = llama_state_seq_get_data(ctx, state_data.data(), slot->id + 1);
-                    GGML_ASSERT(nwrite <= state_size);
-
-                    // write the cached token count of the slot->cache_tokens.size()
-                    memcpy(state_data.data() + nwrite, &token_count, sizeof(size_t));
-                    nwrite += sizeof(size_t);
-
-                    // write the cached tokens (loop)
-                    for (size_t i = 0; i < token_count; i++) {
-                        const llama_token token = slot->cache_tokens[i];
-                        memcpy(state_data.data() + nwrite, &token, sizeof(llama_token));
-                        nwrite += sizeof(llama_token);
-                    }
-                    GGML_ASSERT(nwrite <= state_data.size());
-
-                    std::ofstream outfile(filepath, std::ios::binary);
-                    outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite);
-                    outfile.close();
+                    const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);

                     const int64_t t_end = ggml_time_us();
                     const double t_save_ms = (t_end - t_start) / 1000.0;
@@ -1682,39 +1664,16 @@ struct server_context {
                     std::string filename = task.data["filename"];
                     std::string filepath = task.data["filepath"];

-                    std::ifstream infile(filepath, std::ios::binary);
-                    if (!infile.is_open()) {
-                        send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST);
-                        break;
-                    }
-                    std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
-                    infile.close();
-
-                    size_t nread = llama_state_seq_set_data(ctx, state_data.data(), slot->id + 1);
+                    slot->cache_tokens.resize(slot->n_ctx);
+                    size_t token_count = 0;
+                    size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
                     if (nread == 0) {
+                        slot->cache_tokens.resize(0);
                         send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
                         break;
                     }
-                    GGML_ASSERT(nread <= state_data.size());
-
-                    // restore cached token values
-                    size_t token_count = 0;
-                    if (nread + sizeof(size_t) <= state_data.size()) {
-                        token_count = *reinterpret_cast<size_t *>(state_data.data() + nread);
-                        nread += sizeof(size_t);
-                    }

                     slot->cache_tokens.resize(token_count);
-                    GGML_ASSERT(nread + (token_count * sizeof(llama_token)) <= state_data.size());
-
-                    // tokens are of type llama_token (an integer)
-                    for (size_t i = 0; i < token_count; i++) {
-                        if (nread + sizeof(llama_token) <= state_data.size()) {
-                            slot->cache_tokens[i] = *reinterpret_cast<llama_token *>(state_data.data() + nread);
-                            nread += sizeof(llama_token);
-                        }
-                    }
-                    GGML_ASSERT(nread <= state_data.size());

                     const int64_t t_end = ggml_time_us();
                     const double t_restore_ms = (t_end - t_start) / 1000.0;
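The restore path has one subtlety worth noting: llama_state_seq_load_file fills a caller-provided token buffer, so the server grows cache_tokens to full context capacity before the call and shrinks it to the real prompt length afterwards. A minimal standalone sketch of that pattern (the helper name and its parameters are illustrative, not part of this commit):

    #include <vector>
    #include "llama.h"

    // Hypothetical helper mirroring the simplified restore handler above:
    // returns the restored prompt tokens, or an empty vector on failure.
    static std::vector<llama_token> restore_seq(llama_context * ctx, llama_seq_id seq_id,
                                                size_t n_ctx, const char * path) {
        std::vector<llama_token> tokens(n_ctx); // capacity for tokens_out
        size_t token_count = 0;
        const size_t nread = llama_state_seq_load_file(
            ctx, path, seq_id, tokens.data(), tokens.size(), &token_count);
        if (nread == 0) {
            tokens.clear(); // KV cache full or invalid state file
            return tokens;
        }
        tokens.resize(token_count); // shrink to the actual prompt length
        return tokens;
    }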

llama.cpp

@@ -15133,8 +15133,7 @@ size_t llama_state_seq_get_size(struct llama_context * ctx, llama_seq_id seq_id)
    return s_total;
}

-size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
-    llama_data_buffer_context data_ctx(dst);
+static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
    const auto & kv_self = ctx->kv_self;
    GGML_ASSERT(!kv_self.recurrent); // not implemented
@@ -15238,6 +15237,11 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
    return data_ctx.get_size_written();
}

+size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
+    llama_data_buffer_context data_ctx(dst);
+    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+}
+
size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
    auto & kv_self = ctx->kv_self;
    GGML_ASSERT(!kv_self.recurrent); // not implemented
@@ -15361,6 +15365,78 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
    return nread;
}

+size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(filepath, "wb");
+
+    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
+    file.write_u32(LLAMA_STATE_SEQ_VERSION);
+
+    // save the prompt
+    file.write_u32((uint32_t)n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state using stream saving
+    llama_data_file_context data_ctx(&file);
+    llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+
+    const size_t res = file.tell();
+    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
+    return res;
+}
+
+static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(filepath, "rb");
+
+    // version checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
+            LLAMA_LOG_ERROR("%s : unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
+            return 0;
+        }
+    }
+
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            LLAMA_LOG_ERROR("%s : token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return 0;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
+
+    // restore the context state
+    {
+        const size_t state_size = file.size - file.tell();
+        std::vector<uint8_t> state_data(state_size);
+        file.read_raw(state_data.data(), state_size);
+        const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id);
+        if (!nread) {
+            LLAMA_LOG_ERROR("%s : failed to restore sequence state\n", __func__);
+            return 0;
+        }
+        GGML_ASSERT(nread <= state_size);
+        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
+    }
+
+    return file.tell();
+}
+
+size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    try {
+        return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
+        return 0;
+    }
+}
+
void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
    ctx->cparams.n_threads = n_threads;
    ctx->cparams.n_threads_batch = n_threads_batch;
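For reference, the writer above implies the following on-disk layout for the new sequence state files (offsets assume the usual 4-byte llama_token, i.e. int32_t):

    // sequence state file ('ggsq') as written by llama_state_seq_save_file:
    //
    //   offset  0 : uint32_t  magic    = LLAMA_STATE_SEQ_MAGIC   (0x67677371, 'ggsq')
    //   offset  4 : uint32_t  version  = LLAMA_STATE_SEQ_VERSION (1)
    //   offset  8 : uint32_t  n_token_count
    //   offset 12 : llama_token[n_token_count]   the prompt tokens
    //   after     : the sequence state from llama_state_seq_get_data, to end of file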

llama.h

@@ -37,10 +37,14 @@
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
+#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'

 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
 #define LLAMA_SESSION_VERSION 5

+#define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
+#define LLAMA_STATE_SEQ_VERSION 1
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -667,6 +671,21 @@ extern "C" {
                       const uint8_t * src,
                        llama_seq_id   dest_seq_id);

+    LLAMA_API size_t llama_state_seq_save_file(
+            struct llama_context * ctx,
+                      const char * filepath,
+                    llama_seq_id   seq_id,
+               const llama_token * tokens,
+                          size_t   n_token_count);
+
+    LLAMA_API size_t llama_state_seq_load_file(
+            struct llama_context * ctx,
+                      const char * filepath,
+                    llama_seq_id   dest_seq_id,
+                     llama_token * tokens_out,
+                          size_t   n_token_capacity,
+                          size_t * n_token_count_out);
+
     //
     // Decoding
     //
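With the declarations in place, the new calls follow the same shape as the existing session API. A hedged usage sketch, assuming ctx already holds an evaluated prompt in sequence 0 and the KV cache has room for a copy (the file name and helper are illustrative):

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    // Sketch: round-trip sequence 0 through a state file, restoring it as sequence 1.
    static void demo_seq_file(llama_context * ctx, const std::vector<llama_token> & tokens) {
        // persist sequence 0 together with its prompt tokens
        const size_t nwrite = llama_state_seq_save_file(
            ctx, "seq0.bin", 0, tokens.data(), tokens.size());
        printf("saved %zu bytes\n", nwrite);

        // load the file back into a different sequence of the same context
        std::vector<llama_token> restored(tokens.size());
        size_t n_restored = 0;
        const size_t nread = llama_state_seq_load_file(
            ctx, "seq0.bin", 1, restored.data(), restored.size(), &n_restored);
        if (nread == 0) {
            fprintf(stderr, "restore failed: bad file or no space in the KV cache\n");
            return;
        }
        restored.resize(n_restored);
        // sequence 1 can now continue decoding from the restored state
    }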