convert checks in llama_load_session_file to throw and handle them
This commit is contained in:
parent
0be54f75a6
commit
ced65a56b0
1 changed files with 14 additions and 10 deletions
22
llama.cpp
22
llama.cpp
|
@ -3336,7 +3336,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
|
|||
return nread;
|
||||
}
|
||||
|
||||
bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||
void llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||
llama_file file(path_session, "rb");
|
||||
|
||||
// sanity checks
|
||||
|
@ -3345,16 +3345,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
|
|||
const uint32_t version = file.read_u32();
|
||||
|
||||
if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
|
||||
fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
|
||||
return false;
|
||||
throw std::runtime_error(format("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version));
|
||||
}
|
||||
|
||||
llama_hparams session_hparams;
|
||||
file.read_raw(&session_hparams, sizeof(llama_hparams));
|
||||
|
||||
if (session_hparams != ctx->model.hparams) {
|
||||
fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
|
||||
return false;
|
||||
throw std::runtime_error(format("%s : model hparams didn't match from session file!\n", __func__));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3363,8 +3361,7 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
|
|||
const uint32_t n_token_count = file.read_u32();
|
||||
|
||||
if (n_token_count > n_token_capacity) {
|
||||
fprintf(stderr, "%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
|
||||
return false;
|
||||
throw std::runtime_error(format("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity));
|
||||
}
|
||||
|
||||
file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
|
||||
|
@ -3377,8 +3374,7 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
|
|||
const size_t n_state_size_max = llama_get_state_size(ctx);
|
||||
|
||||
if (n_state_size_cur > n_state_size_max) {
|
||||
fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
|
||||
return false;
|
||||
throw std::runtime_error(format("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur));
|
||||
}
|
||||
|
||||
std::vector<uint8_t> state_data(n_state_size_max);
|
||||
|
@ -3386,8 +3382,16 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
|
|||
|
||||
llama_set_state_data(ctx, state_data.data());
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
|
||||
try {
|
||||
llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
|
||||
return true;
|
||||
} catch (const std::exception & err) {
|
||||
fprintf(stderr, "error loading session file: %s\n", err.what());
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue