rename state get set functions
This commit is contained in:
parent
c4443d7ad4
commit
4d5356bbbb
4 changed files with 29 additions and 29 deletions
|
@ -235,7 +235,7 @@ int main(int argc, char ** argv) {
|
||||||
// The file exists and is not empty
|
// The file exists and is not empty
|
||||||
session_tokens.resize(n_ctx);
|
session_tokens.resize(n_ctx);
|
||||||
size_t n_token_count_out = 0;
|
size_t n_token_count_out = 0;
|
||||||
if (!llama_load_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
|
if (!llama_state_load_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.capacity(), &n_token_count_out)) {
|
||||||
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
|
LOG_TEE("%s: error: failed to load session file '%s'\n", __func__, path_session.c_str());
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -693,7 +693,7 @@ int main(int argc, char ** argv) {
|
||||||
// optionally save the session on first sample (for faster prompt loading next time)
|
// optionally save the session on first sample (for faster prompt loading next time)
|
||||||
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
|
if (!path_session.empty() && need_to_save_session && !params.prompt_cache_ro) {
|
||||||
need_to_save_session = false;
|
need_to_save_session = false;
|
||||||
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||||
|
|
||||||
LOG("saved session to %s\n", path_session.c_str());
|
LOG("saved session to %s\n", path_session.c_str());
|
||||||
}
|
}
|
||||||
|
@ -935,7 +935,7 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
|
if (!path_session.empty() && params.prompt_cache_all && !params.prompt_cache_ro) {
|
||||||
LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
|
LOG_TEE("\n%s: saving final output to session file '%s'\n", __func__, path_session.c_str());
|
||||||
llama_save_session_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_print_timings(ctx);
|
llama_print_timings(ctx);
|
||||||
|
|
|
@ -45,8 +45,8 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// save state (rng, logits, embedding and kv_cache) to file
|
// save state (rng, logits, embedding and kv_cache) to file
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> state_mem(llama_get_state_size(ctx));
|
std::vector<uint8_t> state_mem(llama_state_get_size(ctx));
|
||||||
const size_t written = llama_copy_state_data(ctx, state_mem.data());
|
const size_t written = llama_state_get_data(ctx, state_mem.data());
|
||||||
|
|
||||||
FILE *fp_write = fopen("dump_state.bin", "wb");
|
FILE *fp_write = fopen("dump_state.bin", "wb");
|
||||||
fwrite(state_mem.data(), 1, written, fp_write);
|
fwrite(state_mem.data(), 1, written, fp_write);
|
||||||
|
@ -98,13 +98,13 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// load state (rng, logits, embedding and kv_cache) from file
|
// load state (rng, logits, embedding and kv_cache) from file
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> state_mem(llama_get_state_size(ctx2));
|
std::vector<uint8_t> state_mem(llama_state_get_size(ctx2));
|
||||||
|
|
||||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||||
fclose(fp_read);
|
fclose(fp_read);
|
||||||
|
|
||||||
if (read != llama_set_state_data(ctx2, state_mem.data())) {
|
if (read != llama_state_set_data(ctx2, state_mem.data())) {
|
||||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||||
llama_free(ctx2);
|
llama_free(ctx2);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
@ -158,13 +158,13 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// load state (rng, logits, embedding and kv_cache) from file
|
// load state (rng, logits, embedding and kv_cache) from file
|
||||||
{
|
{
|
||||||
std::vector<uint8_t> state_mem(llama_get_state_size(ctx3));
|
std::vector<uint8_t> state_mem(llama_state_get_size(ctx3));
|
||||||
|
|
||||||
FILE * fp_read = fopen("dump_state.bin", "rb");
|
FILE * fp_read = fopen("dump_state.bin", "rb");
|
||||||
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read);
|
||||||
fclose(fp_read);
|
fclose(fp_read);
|
||||||
|
|
||||||
if (read != llama_set_state_data(ctx3, state_mem.data())) {
|
if (read != llama_state_set_data(ctx3, state_mem.data())) {
|
||||||
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
fprintf(stderr, "\n%s : failed to read state\n", __func__);
|
||||||
llama_free(ctx3);
|
llama_free(ctx3);
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
|
30
llama.cpp
30
llama.cpp
|
@ -14570,7 +14570,7 @@ void llama_kv_cache_update(struct llama_context * ctx) {
|
||||||
|
|
||||||
|
|
||||||
// Returns the *maximum* size of the state
|
// Returns the *maximum* size of the state
|
||||||
size_t llama_get_state_size(const struct llama_context * ctx) {
|
size_t llama_state_get_size(const struct llama_context * ctx) {
|
||||||
const auto & cparams = ctx->cparams;
|
const auto & cparams = ctx->cparams;
|
||||||
const auto & hparams = ctx->model.hparams;
|
const auto & hparams = ctx->model.hparams;
|
||||||
|
|
||||||
|
@ -14658,15 +14658,15 @@ struct llama_data_file_context : llama_data_context {
|
||||||
* file context:
|
* file context:
|
||||||
* llama_file file("/path", "wb");
|
* llama_file file("/path", "wb");
|
||||||
* llama_data_file_context data_ctx(&file);
|
* llama_data_file_context data_ctx(&file);
|
||||||
* llama_copy_state_data(ctx, &data_ctx);
|
* llama_state_get_data(ctx, &data_ctx);
|
||||||
*
|
*
|
||||||
* buffer context:
|
* buffer context:
|
||||||
* std::vector<uint8_t> buf(max_size, 0);
|
* std::vector<uint8_t> buf(max_size, 0);
|
||||||
* llama_data_buffer_context data_ctx(&buf.data());
|
* llama_data_buffer_context data_ctx(&buf.data());
|
||||||
* llama_copy_state_data(ctx, &data_ctx);
|
* llama_state_get_data(ctx, &data_ctx);
|
||||||
*
|
*
|
||||||
*/
|
*/
|
||||||
static void llama_copy_state_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
|
static void llama_state_get_data_internal(struct llama_context * ctx, llama_data_context * data_ctx) {
|
||||||
// copy rng
|
// copy rng
|
||||||
{
|
{
|
||||||
std::ostringstream rng_ss;
|
std::ostringstream rng_ss;
|
||||||
|
@ -14810,15 +14810,15 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) {
|
size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst) {
|
||||||
llama_data_buffer_context data_ctx(dst);
|
llama_data_buffer_context data_ctx(dst);
|
||||||
llama_copy_state_data_internal(ctx, &data_ctx);
|
llama_state_get_data_internal(ctx, &data_ctx);
|
||||||
|
|
||||||
return data_ctx.get_size_written();
|
return data_ctx.get_size_written();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Sets the state reading from the specified source address
|
// Sets the state reading from the specified source address
|
||||||
size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
const uint8_t * inp = src;
|
const uint8_t * inp = src;
|
||||||
|
|
||||||
// set rng
|
// set rng
|
||||||
|
@ -14970,14 +14970,14 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) {
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t nread = inp - src;
|
const size_t nread = inp - src;
|
||||||
const size_t max_size = llama_get_state_size(ctx);
|
const size_t max_size = llama_state_get_size(ctx);
|
||||||
|
|
||||||
GGML_ASSERT(nread <= max_size);
|
GGML_ASSERT(nread <= max_size);
|
||||||
|
|
||||||
return nread;
|
return nread;
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||||
llama_file file(path_session, "rb");
|
llama_file file(path_session, "rb");
|
||||||
|
|
||||||
// sanity checks
|
// sanity checks
|
||||||
|
@ -15015,7 +15015,7 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c
|
||||||
// restore the context state
|
// restore the context state
|
||||||
{
|
{
|
||||||
const size_t n_state_size_cur = file.size - file.tell();
|
const size_t n_state_size_cur = file.size - file.tell();
|
||||||
const size_t n_state_size_max = llama_get_state_size(ctx);
|
const size_t n_state_size_max = llama_state_get_size(ctx);
|
||||||
|
|
||||||
if (n_state_size_cur > n_state_size_max) {
|
if (n_state_size_cur > n_state_size_max) {
|
||||||
LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
|
LLAMA_LOG_ERROR("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
|
||||||
|
@ -15025,22 +15025,22 @@ static bool llama_load_session_file_internal(struct llama_context * ctx, const c
|
||||||
std::vector<uint8_t> state_data(n_state_size_max);
|
std::vector<uint8_t> state_data(n_state_size_max);
|
||||||
file.read_raw(state_data.data(), n_state_size_cur);
|
file.read_raw(state_data.data(), n_state_size_cur);
|
||||||
|
|
||||||
llama_set_state_data(ctx, state_data.data());
|
llama_state_set_data(ctx, state_data.data());
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
bool llama_state_load_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||||
try {
|
try {
|
||||||
return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
||||||
} catch (const std::exception & err) {
|
} catch (const std::exception & err) {
|
||||||
LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
|
LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
||||||
llama_file file(path_session, "wb");
|
llama_file file(path_session, "wb");
|
||||||
|
|
||||||
file.write_u32(LLAMA_SESSION_MAGIC);
|
file.write_u32(LLAMA_SESSION_MAGIC);
|
||||||
|
@ -15054,7 +15054,7 @@ bool llama_save_session_file(struct llama_context * ctx, const char * path_sessi
|
||||||
|
|
||||||
// save the context state using stream saving
|
// save the context state using stream saving
|
||||||
llama_data_file_context data_ctx(&file);
|
llama_data_file_context data_ctx(&file);
|
||||||
llama_copy_state_data_internal(ctx, &data_ctx);
|
llama_state_get_data_internal(ctx, &data_ctx);
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
10
llama.h
10
llama.h
|
@ -594,30 +594,30 @@ extern "C" {
|
||||||
|
|
||||||
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
// Returns the maximum size in bytes of the state (rng, logits, embedding
|
||||||
// and kv_cache) - will often be smaller after compacting tokens
|
// and kv_cache) - will often be smaller after compacting tokens
|
||||||
LLAMA_API size_t llama_get_state_size(const struct llama_context * ctx);
|
LLAMA_API size_t llama_state_get_size(const struct llama_context * ctx);
|
||||||
|
|
||||||
// Copies the state to the specified destination address.
|
// Copies the state to the specified destination address.
|
||||||
// Destination needs to have allocated enough memory.
|
// Destination needs to have allocated enough memory.
|
||||||
// Returns the number of bytes copied
|
// Returns the number of bytes copied
|
||||||
LLAMA_API size_t llama_copy_state_data(
|
LLAMA_API size_t llama_state_get_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
uint8_t * dst);
|
uint8_t * dst);
|
||||||
|
|
||||||
// Set the state reading from the specified address
|
// Set the state reading from the specified address
|
||||||
// Returns the number of bytes read
|
// Returns the number of bytes read
|
||||||
LLAMA_API size_t llama_set_state_data(
|
LLAMA_API size_t llama_state_set_data(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const uint8_t * src);
|
const uint8_t * src);
|
||||||
|
|
||||||
// Save/load session file
|
// Save/load session file
|
||||||
LLAMA_API bool llama_load_session_file(
|
LLAMA_API bool llama_state_load_file(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const char * path_session,
|
const char * path_session,
|
||||||
llama_token * tokens_out,
|
llama_token * tokens_out,
|
||||||
size_t n_token_capacity,
|
size_t n_token_capacity,
|
||||||
size_t * n_token_count_out);
|
size_t * n_token_count_out);
|
||||||
|
|
||||||
LLAMA_API bool llama_save_session_file(
|
LLAMA_API bool llama_state_save_file(
|
||||||
struct llama_context * ctx,
|
struct llama_context * ctx,
|
||||||
const char * path_session,
|
const char * path_session,
|
||||||
const llama_token * tokens,
|
const llama_token * tokens,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue