llama : more graceful error handling of invalid session files

* llama : remove LLAMA_MAX_RNG_STATE

It's no longer necessary to limit the size of the RNG state,
because the maximum size of session files is no longer estimated.
Francis Couture-Harpin 2024-07-27 14:31:57 -04:00
parent 83e6a17ddf
commit ffd5117def
2 changed files with 82 additions and 51 deletions
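For illustration only (not part of the commit): with these changes, a corrupt, truncated, or incompatible session file is reported through the existing failure paths (an error log plus a false or 0 return value) instead of aborting the whole process through GGML_ASSERT. A minimal caller sketch, assuming only the public llama.h API; the helper name and its fallback behaviour are hypothetical:

    #include <cstdio>
    #include <vector>

    #include "llama.h"

    // Hypothetical helper (not part of the commit): try to restore a saved
    // session, falling back to an empty token list when the file is missing,
    // truncated, or incompatible. llama_state_load_file() logs the problem and
    // returns false in those cases instead of asserting.
    static std::vector<llama_token> try_restore_session(llama_context * ctx, const char * path) {
        std::vector<llama_token> tokens(llama_n_ctx(ctx));
        size_t n_token_count = 0;

        if (!llama_state_load_file(ctx, path, tokens.data(), tokens.size(), &n_token_count)) {
            fprintf(stderr, "session restore failed, starting with an empty context\n");
            n_token_count = 0;
        }

        tokens.resize(n_token_count);
        return tokens;
    }

The same kind of check applies to llama_state_set_data(), which after this change returns 0 rather than the number of bytes read when the buffer contents are invalid.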

@@ -33,8 +33,6 @@
 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF

-#define LLAMA_MAX_RNG_STATE (64*1024)
-
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'

@@ -17302,18 +17302,16 @@ struct llama_data_write {
     virtual size_t get_size_written() = 0;
     virtual ~llama_data_write() = default;

-    void write_string(const std::string & str, uint32_t max_size) {
+    void write_string(const std::string & str) {
         uint32_t str_size = str.size();

-        GGML_ASSERT(str_size <= max_size);
-
         write(&str_size, sizeof(str_size));
         write(str.data(), str_size);
     }

     void write_model_info(const struct llama_context * ctx) {
         std::string arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
-        write_string(arch_str, arch_str.size());
+        write_string(arch_str);
         // TODO: add more model-specific info which should prevent loading the session file if not identical
     }
@@ -17323,7 +17321,7 @@ struct llama_data_write {
         const std::string & rng_str = rng_ss.str();

-        write_string(rng_str, LLAMA_MAX_RNG_STATE);
+        write_string(rng_str);
     }

     void write_output_ids(const struct llama_context * ctx) {
@@ -17527,12 +17525,10 @@ struct llama_data_read {
     virtual size_t get_size_read() = 0;
     virtual ~llama_data_read() = default;

-    void read_string(std::string & str, uint32_t max_size) {
+    void read_string(std::string & str) {
         uint32_t str_size;
         read_to(&str_size, sizeof(str_size));

-        GGML_ASSERT(str_size <= max_size);
-
         str.assign((const char *) read(str_size), str_size);
     }
@@ -17540,7 +17536,7 @@ struct llama_data_read {
     void read_model_info(const struct llama_context * ctx) {
         std::string cur_arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
         std::string arch_str;
-        read_string(arch_str, cur_arch_str.size());
+        read_string(arch_str);
         if (cur_arch_str != arch_str) {
             throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
         }
@@ -17549,12 +17545,14 @@ struct llama_data_read {
     void read_rng(std::mt19937 & rng) {
         std::string rng_str;
-        read_string(rng_str, LLAMA_MAX_RNG_STATE);
+        read_string(rng_str);

         std::istringstream rng_ss(rng_str);
         rng_ss >> rng;

-        GGML_ASSERT(!rng_ss.fail());
+        if (rng_ss.fail()) {
+            throw std::runtime_error("failed to load RNG state");
+        }
     }

     void read_output_ids(struct llama_context * ctx) {
@@ -17564,7 +17562,7 @@ struct llama_data_read {
         read_to(&n_outputs, sizeof(n_outputs));

         if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
-            GGML_ASSERT(false && "could not reserve outputs");
+            throw std::runtime_error("could not reserve outputs");
         }

         if (n_outputs) {
@@ -17573,7 +17571,9 @@ struct llama_data_read {
             for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
                 int32_t id = output_pos[i];
-                GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
+                if ((uint32_t) id >= ctx->cparams.n_batch) {
+                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
+                }
                 ctx->output_ids[id] = i;
             }
@@ -17585,7 +17585,9 @@ struct llama_data_read {
         uint64_t logits_size;
         read_to(&logits_size, sizeof(logits_size));

-        GGML_ASSERT(ctx->logits_size >= logits_size);
+        if (ctx->logits_size < logits_size) {
+            throw std::runtime_error("logits buffer too small");
+        }

         if (logits_size) {
             read_to(ctx->logits, logits_size * sizeof(float));
@@ -17596,18 +17598,22 @@ struct llama_data_read {
         uint64_t embeddings_size;
         read_to(&embeddings_size, sizeof(embeddings_size));

-        GGML_ASSERT(ctx->embd_size >= embeddings_size);
+        if (ctx->embd_size < embeddings_size) {
+            throw std::runtime_error("embeddings buffer too small");
+        }

         if (embeddings_size) {
             read_to(ctx->embd, embeddings_size * sizeof(float));
         }
     }

-    bool read_kv_cache_meta(struct llama_kv_cache & kv_self, uint32_t cell_count, llama_seq_id seq_id = -1) {
-        if (seq_id != -1) {
+    bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
+        struct llama_kv_cache & kv_self = ctx->kv_self;
+
+        if (dest_seq_id != -1) {
             // single sequence
-            llama_kv_cache_seq_rm(kv_self, seq_id, -1, -1);
+            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);

             llama_batch batch = llama_batch_init(cell_count, 0, 1);
             batch.n_tokens = cell_count;
@@ -17618,11 +17624,14 @@ struct llama_data_read {
                 read_to(&pos, sizeof(pos));
                 read_to(&n_seq_id, sizeof(n_seq_id));

-                GGML_ASSERT(n_seq_id == 0);
+                if (n_seq_id != 0) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                    return false;
+                }

                 batch.pos[i] = pos;
                 batch.n_seq_id[i] = 1;
-                batch.seq_id[i][0] = seq_id;
+                batch.seq_id[i][0] = dest_seq_id;
             }

             if (!llama_kv_cache_find_slot(kv_self, batch)) {
                 llama_batch_free(batch);
@@ -17635,8 +17644,8 @@ struct llama_data_read {
             GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
             GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
             GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(seq_id));
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(seq_id));
+            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));

             // Cleanup
             llama_batch_free(batch);
@@ -17661,12 +17670,16 @@ struct llama_data_read {
                 cell.pos = pos;

-                // TODO: more sanity checks for seq_ids
                 for (uint32_t j = 0; j < n_seq_id; ++j) {
-                    llama_seq_id s;
-                    read_to(&s, sizeof(s));
-                    cell.seq_id.insert(s);
+                    llama_seq_id seq_id;
+                    read_to(&seq_id, sizeof(seq_id));
+
+                    if (seq_id < 0 || seq_id >= llama_n_seq_max(ctx)) {
+                        LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                        return false;
+                    }
+
+                    cell.seq_id.insert(seq_id);
                 }
             }
@@ -17685,9 +17698,18 @@ struct llama_data_read {
         read_to(&v_trans, sizeof(v_trans));
         read_to(&n_layer, sizeof(n_layer));

-        GGML_ASSERT(n_layer == hparams.n_layer);
-        GGML_ASSERT(cell_count <= kv_self.size);
-        GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition
+        if (n_layer != hparams.n_layer) {
+            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+            return false;
+        }
+        if (cell_count > kv_self.size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
+            return false;
+        }
+        if (kv_self.v_trans != (bool) v_trans) {
+            LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+            return false;
+        }

         // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
         for (uint32_t il = 0; il < n_layer; ++il) {
@@ -17698,7 +17720,6 @@ struct llama_data_read {
             read_to(&k_type_i_ref, sizeof(k_type_i_ref));
             const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
             if (k_type_i != k_type_i_ref) {
-                // llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
                 LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
                 return false;
             }
@@ -17708,7 +17729,6 @@ struct llama_data_read {
             read_to(&k_size_row_ref, sizeof(k_size_row_ref));
             const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
             if (k_size_row != k_size_row_ref) {
-                // llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
                 LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
                 return false;
             }
@@ -17728,7 +17748,6 @@ struct llama_data_read {
             read_to(&v_type_i_ref, sizeof(v_type_i_ref));
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             if (v_type_i != v_type_i_ref) {
-                // llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
                 LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                 return false;
             }
@@ -17790,11 +17809,11 @@ struct llama_data_read {
         return true;
     }

-    bool read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
+    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
         uint32_t cell_count;
         read_to(&cell_count, sizeof(cell_count));

-        bool res = read_kv_cache_meta(ctx->kv_self, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
+        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);

         if (!res) {
             if (seq_id == -1) {
@@ -17802,9 +17821,8 @@ struct llama_data_read {
             } else {
                 llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
             }
+            throw std::runtime_error("failed to restore kv cache");
         }
-
-        return res;
     }
 };
@@ -17944,14 +17962,24 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
 size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
     llama_data_write_buffer data_ctx(dst, size);
+    try {
         return llama_state_get_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
+        return 0;
+    }
 }

 // Returns the *actual* size of the state.
 // Intended to be used when saving to state to a buffer.
 size_t llama_state_get_size(struct llama_context * ctx) {
     llama_data_write_dummy data_ctx;
+    try {
         return llama_state_get_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
+    }
 }

 static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
@@ -17975,7 +18003,12 @@ static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_da
 // Sets the state reading from the specified source address
 size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
     llama_data_read_buffer data_ctx(src, size);
+    try {
         return llama_state_set_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
+        return 0;
+    }
 }

 static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
@@ -18024,7 +18057,7 @@ bool llama_state_load_file(struct llama_context * ctx, const char * path_session
     try {
         return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
         return false;
     }
 }
@@ -18050,7 +18083,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
     try {
         return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving session file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
         return false;
     }
 }
@@ -18073,7 +18106,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
     try {
         return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving sequence state: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
         return 0;
     }
 }
@@ -18091,7 +18124,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     try {
         return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading sequence state: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
         return 0;
     }
 }
@@ -18162,7 +18195,7 @@ size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepa
     try {
         return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
         return 0;
     }
 }
@@ -18171,7 +18204,7 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
     try {
         return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
        return 0;
     }
 }
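The per-sequence entry points follow the same pattern: on invalid input, llama_state_seq_set_data() and llama_state_seq_load_file() log an error and return 0, and read_kv_cache() above removes the destination sequence from the KV cache before throwing. A hedged caller sketch, again assuming only the public llama.h API; the helper itself is hypothetical:

    #include <cstdio>
    #include <vector>

    #include "llama.h"

    // Hypothetical helper (not part of the commit): restore one sequence's state
    // from a file saved with llama_state_seq_save_file(). A return value of 0
    // means nothing was restored (bad magic, wrong model arch, mismatched KV
    // layout, out-of-range seq_id, ...), so the caller can simply re-evaluate
    // the prompt for that sequence.
    static bool try_restore_seq(llama_context * ctx, const char * path, llama_seq_id dest_seq_id,
                                std::vector<llama_token> & tokens) {
        tokens.resize(llama_n_ctx(ctx));
        size_t n_token_count = 0;

        const size_t n_read = llama_state_seq_load_file(ctx, path, dest_seq_id,
                                                        tokens.data(), tokens.size(), &n_token_count);
        if (n_read == 0) {
            fprintf(stderr, "could not restore sequence %d, recomputing its prompt\n", dest_seq_id);
            tokens.clear();
            return false;
        }

        tokens.resize(n_token_count);
        return true;
    }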