diff --git a/llama.cpp b/llama.cpp index 5e1747842..615804b56 100644 --- a/llama.cpp +++ b/llama.cpp @@ -15064,7 +15064,7 @@ bool llama_state_load_file(struct llama_context * ctx, const char * path_session } } -bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { +static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { llama_file file(path_session, "wb"); file.write_u32(LLAMA_SESSION_MAGIC); @@ -15083,6 +15083,15 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session return true; } +bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { + try { + return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error saving session file: %s\n", err.what()); + return false; + } +} + size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) { // save the size of size_t as a uint32_t for safety check const size_t size_t_size_size = sizeof(uint32_t); @@ -15365,7 +15374,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src, return nread; } -size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { +static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { llama_file file(filepath, "wb"); file.write_u32(LLAMA_STATE_SEQ_MAGIC); @@ -15428,12 +15437,21 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con return file.tell(); } +size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { + try { + return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what()); + return 0; + } +} + size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { try { return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out); } catch (const std::exception & err) { LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what()); - return false; + return 0; } }