rename state get set functions

Jan Boon 2024-03-28 22:19:57 +08:00
parent c4443d7ad4
commit 4d5356bbbb
4 changed files with 29 additions and 29 deletions
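In short, this commit renames the state get/set and session-file functions to a consistent llama_state_* prefix. The complete mapping, as applied by the hunks below (the *_internal helpers are renamed accordingly):

    llama_get_state_size    -> llama_state_get_size
    llama_copy_state_data   -> llama_state_get_data
    llama_set_state_data    -> llama_state_set_data
    llama_load_session_file -> llama_state_load_file
    llama_save_session_file -> llama_state_save_file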

examples/main/main.cpp

@@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
             // The file exists and is not empty
             session_tokens.resize(n_ctx);
             size_t n_token_count_out = 0;
-            if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
+            if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
                 LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
                 return 1;
             }
@@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
             // optionally save the session on first sample (for faster prompt loading next time)
             if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
                 need_to_save_session = false;
-                llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+                llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
                 LOG("saved session to %s\n", path_session.c_str());
             }
@@ -935,7 +935,7 @@ int main(int argc, char ** argv) {
     if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
         LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
-        llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
+        llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
     }

     llama_print_timings(ctx);
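Taken together, the three hunks above cover the whole session-file flow in the main example. A minimal sketch of that pattern under the new names (illustrative only: the "session.bin" path and the surrounding setup are placeholders, not code from this commit):

    // attempt to restore a previous session into a pre-sized token buffer
    std::vector<llama_token> session_tokens(n_ctx);
    size_t n_token_count_out = 0;
    if (llama_state_load_file(ctx, "session.bin",
                              session_tokens.data(), session_tokens.capacity(),
                              &n_token_count_out)) {
        session_tokens.resize(n_token_count_out); // keep only the restored tokens
    }

    // ... evaluate the prompt, generate ...

    // persist the token list plus the full context state for the next run
    llama_state_save_file(ctx, "session.bin",
                          session_tokens.data(), session_tokens.size());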

examples/save-load-state/save-load-state.cpp

@@ -45,8 +45,8 @@ int main(int argc, char ** argv) {
     // save state (rng, logits, embedding and kv_cache) to file
     {
-        std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
-        const size_t written = llama_copy_state_data(ctx, state_mem.data());
+        std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
+        const size_t written = llama_state_get_data(ctx, state_mem.data());

         FILE *fp_write = fopen("dump_state.bin", "wb");
         fwrite(state_mem.data(), 1, written, fp_write);
@@ -98,13 +98,13 @@ int main(int argc, char ** argv) {
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
+        std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));

         FILE * fp_read = fopen("dump_state.bin", "rb");
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);

-        if (read != llama_set_state_data(ctx2, state_mem.data())) {
+        if (read != llama_state_set_data(ctx2, state_mem.data())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx2);
             llama_free_model(model);
@@ -158,13 +158,13 @@ int main(int argc, char ** argv) {
     // load state (rng, logits, embedding and kv_cache) from file
     {
-        std::vector<uint8_t> state_mem(llama_get_state_size(ctx3));
+        std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));

         FILE * fp_read = fopen("dump_state.bin", "rb");
         const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
         fclose(fp_read);

-        if (read != llama_set_state_data(ctx3, state_mem.data())) {
+        if (read != llama_state_set_data(ctx3, state_mem.data())) {
             fprintf(stderr, "\n%s : failed to read state\n", __func__);
             llama_free(ctx3);
             llama_free_model(model);

llama.cpp

@@ -14570,7 +14570,7 @@ void llama_kv_cache_update(struct llama_context * ctx) {

 // Returns the *maximum* size of the state
-size_t llama_get_state_size(const struct llama_context * ctx) {
+size_t llama_state_get_size(const struct llama_context * ctx) {
     const auto & cparams = ctx->cparams;
     const auto & hparams = ctx->model.hparams;
@@ -14658,15 +14658,15 @@ struct llama_data_file_context : llama_data_context {
  * file context:
  * llama_file file("/path", "wb");
  * llama_data_file_context data_ctx(&file);
- * llama_copy_state_data(ctx, &data_ctx);
+ * llama_state_get_data(ctx, &data_ctx);
  *
  * buffer context:
  * std::vector<uint8_t> buf(max_size, 0);
  * llama_data_buffer_context data_ctx(&buf.data());
- * llama_copy_state_data(ctx, &data_ctx);
+ * llama_state_get_data(ctx, &data_ctx);
  *
 */
-static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
+static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
     // copy rng
     {
         std::ostringstream rng_ss;
@@ -14810,15 +14810,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
     }
 }

-size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
+size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) {
     llama_data_buffer_context data_ctx(dst);
-    llama_copy_state_data_internal(ctx, &data_ctx);
+    llama_state_get_data_internal(ctx, &data_ctx);

     return data_ctx.get_size_written();
 }

 // Sets the state reading from the specified source address
-size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
+size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
     const uint8_t * inp = src;

     // set rng
@@ -14970,14 +14970,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
     }

     const size_t nread = inp - src;
-    const size_t max_size = llama_get_state_size(ctx);
+    const size_t max_size = llama_state_get_size(ctx);

     GGML_ASSERT(nread <= max_size);

     return nread;
 }

-static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     llama_file file(path_session, "rb");

     // sanity checks
@@ -15015,7 +15015,7 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c
     // restore the context state
     {
         const size_t n_state_size_cur = file.size - file.tell();
-        const size_t n_state_size_max = llama_get_state_size(ctx);
+        const size_t n_state_size_max = llama_state_get_size(ctx);

         if (n_state_size_cur > n_state_size_max) {
             LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
@@ -15025,22 +15025,22 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c
         std::vector<uint8_t> state_data(n_state_size_max);
         file.read_raw(state_data.data(), n_state_size_cur);
-        llama_set_state_data(ctx, state_data.data());
+        llama_state_set_data(ctx, state_data.data());
     }

     return true;
 }

-bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     try {
-        return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
+        return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
         LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
         return false;
     }
 }

-bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
+bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
     llama_file file(path_session, "wb");

     file.write_u32(LLAMA_SESSION_MAGIC);
@@ -15054,7 +15054,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
     // save the context state using stream saving
     llama_data_file_context data_ctx(&file);
-    llama_copy_state_data_internal(ctx, &data_ctx);
+    llama_state_get_data_internal(ctx, &data_ctx);

     return true;
 }

llama.h

@@ -594,30 +594,30 @@ extern "C" {
     // Returns the maximum size in bytes of the state (rng, logits, embedding
     // and kv_cache) - will often be smaller after compacting tokens
-    LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
+    LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);

     // Copies the state to the specified destination address.
     // Destination needs to have allocated enough memory.
     // Returns the number of bytes copied
-    LLAMA_API size_t llama_copy_state_data(
+    LLAMA_API size_t llama_state_get_data(
             struct llama_context * ctx,
             uint8_t * dst);

     // Set the state reading from the specified address
     // Returns the number of bytes read
-    LLAMA_API size_t llama_set_state_data(
+    LLAMA_API size_t llama_state_set_data(
             struct llama_context * ctx,
             const uint8_t * src);

     // Save/load session file
-    LLAMA_API bool llama_load_session_file(
+    LLAMA_API bool llama_state_load_file(
             struct llama_context * ctx,
             const char * path_session,
             llama_token * tokens_out,
             size_t n_token_capacity,
             size_t * n_token_count_out);

-    LLAMA_API bool llama_save_session_file(
+    LLAMA_API bool llama_state_save_file(
             struct llama_context * ctx,
             const char * path_session,
             const llama_token * tokens,
             size_t n_token_count);
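
For the in-memory counterpart of these declarations, a minimal sketch mirroring the save-load-state example above (assumes ctx and ctx2 are valid contexts created from the same model; error handling trimmed):

    // size the buffer for the worst case, then capture the actual state;
    // llama_state_get_size returns the *maximum* size, so the buffer may
    // be larger than what llama_state_get_data actually writes
    std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
    const size_t written = llama_state_get_data(ctx, state_mem.data());

    // restore into the second context; the byte count read back should
    // match the byte count written for the same model and settings
    if (llama_state_set_data(ctx2, state_mem.data()) != written) {
        fprintf(stderr, "failed to restore state\n");
    }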