catch exceptions on save as well

This commit is contained in:
Jan Boon 2024-04-02 04:06:23 +08:00
parent 8af72118ec
commit 3d6fa5bdd7

View file

@ -15064,7 +15064,7 @@ bool llama_state_load_file(struct llama_context * ctx, const char * path_session
} }
} }
bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) { static bool llama_state_save_file_internal(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
llama_file file(path_session, "wb"); llama_file file(path_session, "wb");
file.write_u32(LLAMA_SESSION_MAGIC); file.write_u32(LLAMA_SESSION_MAGIC);
@ -15083,6 +15083,15 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
return true; return true;
} }
bool llama_state_save_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
try {
return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("error saving session file: %s\n", err.what());
return false;
}
}
size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) { size_t llama_state_seq_get_size(struct llama_context* ctx, llama_seq_id seq_id) {
// save the size of size_t as a uint32_t for safety check // save the size of size_t as a uint32_t for safety check
const size_t size_t_size_size = sizeof(uint32_t); const size_t size_t_size_size = sizeof(uint32_t);
@ -15365,7 +15374,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
return nread; return nread;
} }
size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) { static size_t llama_state_seq_save_file_internal(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
llama_file file(filepath, "wb"); llama_file file(filepath, "wb");
file.write_u32(LLAMA_STATE_SEQ_MAGIC); file.write_u32(LLAMA_STATE_SEQ_MAGIC);
@ -15428,12 +15437,21 @@ static size_t llama_state_seq_load_file_internal(struct llama_context * ctx, con
return file.tell(); return file.tell();
} }
size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepath, llama_seq_id seq_id, const llama_token * tokens, size_t n_token_count) {
try {
return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what());
return 0;
}
}
size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) { size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepath, llama_seq_id dest_seq_id, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
try { try {
return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out); return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
} catch (const std::exception & err) { } catch (const std::exception & err) {
LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what()); LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
return false; return 0;
} }
} }