move sequence state file functionality from server to llama to match session api and add version tags
commit 8af72118ec
parent 129b6ffea6

3 changed files with 102 additions and 48 deletions
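In short: the server previously serialized a slot's sequence state by hand (llama_state_seq_get_data for the KV state, then a manually appended token count and token list written through std::ofstream) and restored it with mirror-image parsing code. This commit moves that logic into llama.cpp as llama_state_seq_save_file and llama_state_seq_load_file, matching the shape of the existing session-file API, and stamps each file with a dedicated magic ('ggsq') and version so that incompatible files are rejected at load time.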
examples/server/server.cpp

@@ -1630,26 +1630,8 @@ struct server_context {
             std::string filename = task.data["filename"];
             std::string filepath = task.data["filepath"];

-            size_t state_size = llama_state_seq_get_size(ctx, slot->id + 1);
-            std::vector<uint8_t> state_data(state_size + sizeof(size_t) + token_count * sizeof(llama_token));
-            size_t nwrite = llama_state_seq_get_data(ctx, state_data.data(), slot->id + 1);
-            GGML_ASSERT(nwrite <= state_size);
-
-            // write the cached token count of the slot->cache_tokens.size()
-            memcpy(state_data.data() + nwrite, &token_count, sizeof(size_t));
-            nwrite += sizeof(size_t);
-
-            // write the cached tokens (loop)
-            for (size_t i = 0; i < token_count; i++) {
-                const llama_token token = slot->cache_tokens[i];
-                memcpy(state_data.data() + nwrite, &token, sizeof(llama_token));
-                nwrite += sizeof(llama_token);
-            }
-            GGML_ASSERT(nwrite <= state_data.size());
-
-            std::ofstream outfile(filepath, std::ios::binary);
-            outfile.write(reinterpret_cast<const char *>(state_data.data()), nwrite);
-            outfile.close();
+            const size_t nwrite = llama_state_seq_save_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), token_count);

             const int64_t t_end = ggml_time_us();
             const double t_save_ms = (t_end - t_start) / 1000.0;
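For illustration, here is a minimal sketch of what the save path reduces to. The names ctx, slot_id and cache_tokens are stand-ins for the slot state in the hunk above, not names introduced by this commit:

#include "llama.h"

#include <vector>

// Minimal sketch of the new save path. One call now writes the 'ggsq'
// header, the prompt tokens, and the sequence's KV state; it returns
// the total number of bytes written to the file.
static size_t save_slot_state(llama_context * ctx, int slot_id,
                              const std::vector<llama_token> & cache_tokens,
                              const char * filepath) {
    return llama_state_seq_save_file(ctx, filepath, slot_id + 1,
                                     cache_tokens.data(), cache_tokens.size());
}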
@@ -1682,39 +1664,16 @@ struct server_context {
             std::string filename = task.data["filename"];
             std::string filepath = task.data["filepath"];

-            std::ifstream infile(filepath, std::ios::binary);
-            if (!infile.is_open()) {
-                send_error(task, "Failed to open file", ERROR_TYPE_INVALID_REQUEST);
-                break;
-            }
-
-            std::vector<uint8_t> state_data((std::istreambuf_iterator<char>(infile)), std::istreambuf_iterator<char>());
-            infile.close();
-
-            size_t nread = llama_state_seq_set_data(ctx, state_data.data(), slot->id + 1);
+            slot->cache_tokens.resize(slot->n_ctx);
+            size_t token_count = 0;
+            size_t nread = llama_state_seq_load_file(ctx, filepath.c_str(), slot->id + 1, slot->cache_tokens.data(), slot->cache_tokens.size(), &token_count);
             if (nread == 0) {
+                slot->cache_tokens.resize(0);
                 send_error(task, "Unable to restore slot, no available space in KV cache or invalid slot save file", ERROR_TYPE_INVALID_REQUEST);
                 break;
             }
-            GGML_ASSERT(nread <= state_data.size());
-
-            // restore cached token values
-            size_t token_count = 0;
-            if (nread + sizeof(size_t) <= state_data.size()) {
-                token_count = *reinterpret_cast<size_t *>(state_data.data() + nread);
-                nread += sizeof(size_t);
-            }
             slot->cache_tokens.resize(token_count);
-            GGML_ASSERT(nread + (token_count * sizeof(llama_token)) <= state_data.size());
-
-            // tokens are of type llama_token (an integer)
-            for (size_t i = 0; i < token_count; i++) {
-                if (nread + sizeof(llama_token) <= state_data.size()) {
-                    slot->cache_tokens[i] = *reinterpret_cast<llama_token *>(state_data.data() + nread);
-                    nread += sizeof(llama_token);
-                }
-            }
-            GGML_ASSERT(nread <= state_data.size());

             const int64_t t_end = ggml_time_us();
             const double t_restore_ms = (t_end - t_start) / 1000.0;
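And the matching restore path, sketched the same way; note that the caller pre-sizes the token buffer to its capacity and shrinks it to the count reported back. Again, ctx, slot_id and n_ctx are illustrative stand-ins:

#include "llama.h"

#include <cstdint>
#include <vector>

// Minimal sketch of the new restore path, mirroring the hunk above.
static bool restore_slot_state(llama_context * ctx, int slot_id, uint32_t n_ctx,
                               std::vector<llama_token> & cache_tokens,
                               const char * filepath) {
    // the output buffer must be pre-sized to the capacity we can accept
    cache_tokens.resize(n_ctx);

    size_t token_count = 0;
    const size_t nread = llama_state_seq_load_file(
            ctx, filepath, slot_id + 1,
            cache_tokens.data(), cache_tokens.size(), &token_count);

    if (nread == 0) {
        // bad magic/version, oversized prompt, or no room in the KV cache
        cache_tokens.clear();
        return false;
    }

    // keep only the tokens that were actually restored
    cache_tokens.resize(token_count);
    return true;
}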
llama.cpp (80 changed lines)

@@ -15133,8 +15133,7 @@ size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id)
     return s_total;
 }

-size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
-    llama_data_buffer_context data_ctx(dst);
+static size_t llama_state_seq_get_data_internal(struct llama_context * ctx, llama_data_context & data_ctx, llama_seq_id seq_id) {
     const auto & kv_self = ctx->kv_self;
     GGML_ASSERT(!kv_self.recurrent); // not implemented
@@ -15238,6 +15237,11 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, llama_seq_id seq_id) {
     return data_ctx.get_size_written();
 }

+size_t llama_state_seq_get_data(struct llama_context* ctx, uint8_t* dst, llama_seq_id seq_id) {
+    llama_data_buffer_context data_ctx(dst);
+    return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+}
+
 size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
     auto & kv_self = ctx->kv_self;
     GGML_ASSERT(!kv_self.recurrent); // not implemented
@@ -15361,6 +15365,78 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, llama_seq_id dest_seq_id) {
     return nread;
 }

+size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
+    llama_file file(filepath, "wb");
+
+    file.write_u32(LLAMA_STATE_SEQ_MAGIC);
+    file.write_u32(LLAMA_STATE_SEQ_VERSION);
+
+    // save the prompt
+    file.write_u32((uint32_t)n_token_count);
+    file.write_raw(tokens, sizeof(llama_token) * n_token_count);
+
+    // save the context state using stream saving
+    llama_data_file_context data_ctx(&file);
+    llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
+
+    const size_t res = file.tell();
+    GGML_ASSERT(res == sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count + data_ctx.get_size_written());
+    return res;
+}
+
+static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    llama_file file(filepath, "rb");
+
+    // version checks
+    {
+        const uint32_t magic   = file.read_u32();
+        const uint32_t version = file.read_u32();
+
+        if (magic != LLAMA_STATE_SEQ_MAGIC || version != LLAMA_STATE_SEQ_VERSION) {
+            LLAMA_LOG_ERROR("%s : unknown (magic, version) for sequence state file: %08x, %08x\n", __func__, magic, version);
+            return 0;
+        }
+    }
+
+    // load the prompt
+    {
+        const uint32_t n_token_count = file.read_u32();
+
+        if (n_token_count > n_token_capacity) {
+            LLAMA_LOG_ERROR("%s : token count in sequence state file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return 0;
+        }
+
+        file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
+        *n_token_count_out = n_token_count;
+    }
+
+    // restore the context state
+    {
+        const size_t state_size = file.size - file.tell();
+        std::vector<uint8_t> state_data(state_size);
+        file.read_raw(state_data.data(), state_size);
+        const size_t nread = llama_state_seq_set_data(ctx, state_data.data(), dest_seq_id);
+        if (!nread) {
+            LLAMA_LOG_ERROR("%s : failed to restore sequence state\n", __func__);
+            return 0;
+        }
+        GGML_ASSERT(nread <= state_size);
+        GGML_ASSERT(nread + sizeof(uint32_t) * 3 + sizeof(llama_token) * *n_token_count_out == file.tell());
+    }
+
+    return file.tell();
+}
+
+size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    try {
+        return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
+        return false;
+    }
+}
+
 void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch) {
     ctx->cparams.n_threads       = n_threads;
     ctx->cparams.n_threads_batch = n_threads_batch;
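Reading the two new functions together, the on-disk format they define can be summarized as follows; this layout is inferred from the code above, not separately documented in the commit:

// On-disk layout of a sequence state file as written by
// llama_state_seq_save_file (fields are written in order, no padding):
//
//   uint32_t    magic;                  // LLAMA_STATE_SEQ_MAGIC   = 'ggsq'
//   uint32_t    version;                // LLAMA_STATE_SEQ_VERSION = 1
//   uint32_t    n_token_count;          // number of prompt tokens
//   llama_token tokens[n_token_count];  // the prompt
//   uint8_t     state[];                // rest of file: sequence KV state,
//                                       // from llama_state_seq_get_data_internal

The sizeof(uint32_t) * 3 + sizeof(llama_token) * n_token_count term in the save-side GGML_ASSERT accounts for exactly this header.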
llama.h (19 changed lines)

@@ -37,10 +37,14 @@
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
+#define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'

 #define LLAMA_SESSION_MAGIC   LLAMA_FILE_MAGIC_GGSN
 #define LLAMA_SESSION_VERSION 5

+#define LLAMA_STATE_SEQ_MAGIC   LLAMA_FILE_MAGIC_GGSQ
+#define LLAMA_STATE_SEQ_VERSION 1
+
 #ifdef __cplusplus
 extern "C" {
 #endif
@@ -667,6 +671,21 @@ extern "C" {
             const uint8_t * src,
             llama_seq_id   dest_seq_id);

+    LLAMA_API size_t llama_state_seq_save_file(
+            struct llama_context * ctx,
+            const char * filepath,
+            llama_seq_id seq_id,
+            const llama_token * tokens,
+            size_t n_token_count);
+
+    LLAMA_API size_t llama_state_seq_load_file(
+            struct llama_context * ctx,
+            const char * filepath,
+            llama_seq_id dest_seq_id,
+            llama_token * tokens_out,
+            size_t n_token_capacity,
+            size_t * n_token_count_out);
+
     //
     // Decoding
     //
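To see the two declarations in use together, a hedged round-trip sketch (model and context setup elided; copy_seq_via_file and seq_state.bin are illustrative names, not part of the commit):

#include "llama.h"

#include <vector>

// Sketch: persist sequence 1 of one context and restore it as sequence 2
// of another. Assumes both contexts were created from the same model and
// that sequence 1 of ctx_src already holds an evaluated prompt in `tokens`.
bool copy_seq_via_file(llama_context * ctx_src, llama_context * ctx_dst,
                       const std::vector<llama_token> & tokens) {
    const char * path = "seq_state.bin"; // hypothetical scratch file

    // write header + prompt + sequence state
    const size_t nwrite = llama_state_seq_save_file(
            ctx_src, path, /*seq_id=*/1, tokens.data(), tokens.size());
    if (nwrite == 0) {
        return false;
    }

    // read it back into the destination context
    std::vector<llama_token> tokens_out(tokens.size());
    size_t n_token_count = 0;
    const size_t nread = llama_state_seq_load_file(
            ctx_dst, path, /*dest_seq_id=*/2,
            tokens_out.data(), tokens_out.size(), &n_token_count);

    return nread != 0 && n_token_count == tokens.size();
}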