From ffd5117def1fb91620e21c29707b55c781342e0b Mon Sep 17 00:00:00 2001
From: Francis Couture-Harpin
Date: Sat, 27 Jul 2024 14:31:57 -0400
Subject: [PATCH] llama : more graceful error handling of invalid session files

* llama : remove LLAMA_MAX_RNG_STATE

It's no longer necessary to limit the size of the RNG state,
because the max size of session files is not estimated anymore.
---
 include/llama.h |   2 -
 src/llama.cpp   | 131 ++++++++++++++++++++++++++++++------------------
 2 files changed, 82 insertions(+), 51 deletions(-)

diff --git a/include/llama.h b/include/llama.h
index 578acc405..29c03ce98 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -33,8 +33,6 @@
 
 #define LLAMA_DEFAULT_SEED 0xFFFFFFFF
 
-#define LLAMA_MAX_RNG_STATE (64*1024)
-
 #define LLAMA_FILE_MAGIC_GGLA 0x67676c61u // 'ggla'
 #define LLAMA_FILE_MAGIC_GGSN 0x6767736eu // 'ggsn'
 #define LLAMA_FILE_MAGIC_GGSQ 0x67677371u // 'ggsq'
diff --git a/src/llama.cpp b/src/llama.cpp
index a20eaaefd..ba131951c 100644
--- a/src/llama.cpp
+++ b/src/llama.cpp
@@ -17302,18 +17302,16 @@ struct llama_data_write {
     virtual size_t get_size_written() = 0;
     virtual ~llama_data_write() = default;
 
-    void write_string(const std::string & str, uint32_t max_size) {
+    void write_string(const std::string & str) {
         uint32_t str_size = str.size();
 
-        GGML_ASSERT(str_size <= max_size);
-
         write(&str_size, sizeof(str_size));
         write(str.data(), str_size);
     }
 
     void write_model_info(const struct llama_context * ctx) {
         std::string arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
-        write_string(arch_str, arch_str.size());
+        write_string(arch_str);
         // TODO: add more model-specific info which should prevent loading the session file if not identical
     }
 
@@ -17321,9 +17319,9 @@ struct llama_data_write {
         std::ostringstream rng_ss;
         rng_ss << rng;
 
-        const std::string & rng_str = rng_ss.str();
+        const std::string & rng_str = rng_ss.str();
 
-        write_string(rng_str, LLAMA_MAX_RNG_STATE);
+        write_string(rng_str);
     }
 
     void write_output_ids(const struct llama_context * ctx) {
@@ -17527,12 +17525,10 @@ struct llama_data_read {
     virtual size_t get_size_read() = 0;
     virtual ~llama_data_read() = default;
 
-    void read_string(std::string & str, uint32_t max_size) {
+    void read_string(std::string & str) {
         uint32_t str_size;
         read_to(&str_size, sizeof(str_size));
 
-        GGML_ASSERT(str_size <= max_size);
-
         str.assign((const char *) read(str_size), str_size);
     }
 
@@ -17540,7 +17536,7 @@ void read_model_info(const struct llama_context * ctx) {
         std::string cur_arch_str = LLM_ARCH_NAMES.at(ctx->model.arch);
         std::string arch_str;
-        read_string(arch_str, cur_arch_str.size());
+        read_string(arch_str);
         if (cur_arch_str != arch_str) {
             throw std::runtime_error(format("wrong model arch: '%s' instead of '%s'", arch_str.c_str(), cur_arch_str.c_str()));
         }
     }
@@ -17549,12 +17545,14 @@
     void read_rng(std::mt19937 & rng) {
         std::string rng_str;
 
-        read_string(rng_str, LLAMA_MAX_RNG_STATE);
+        read_string(rng_str);
 
         std::istringstream rng_ss(rng_str);
         rng_ss >> rng;
 
-        GGML_ASSERT(!rng_ss.fail());
+        if (rng_ss.fail()) {
+            throw std::runtime_error("failed to load RNG state");
+        }
     }
 
     void read_output_ids(struct llama_context * ctx) {
@@ -17564,7 +17562,7 @@
         read_to(&n_outputs, sizeof(n_outputs));
 
         if (n_outputs > llama_output_reserve(*ctx, n_outputs)) {
-            GGML_ASSERT(false && "could not reserve outputs");
+            throw std::runtime_error("could not reserve outputs");
         }
 
         if (n_outputs) {
@@ -17573,7 +17571,9 @@
 
             for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) {
                 int32_t id = output_pos[i];
-                GGML_ASSERT((uint32_t) id < ctx->cparams.n_batch);
+                if ((uint32_t) id >= ctx->cparams.n_batch) {
+                    throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, ctx->cparams.n_batch));
+                }
                 ctx->output_ids[id] = i;
             }
 
@@ -17585,7 +17585,9 @@
         uint64_t logits_size;
         read_to(&logits_size, sizeof(logits_size));
-        GGML_ASSERT(ctx->logits_size >= logits_size);
+        if (ctx->logits_size < logits_size) {
+            throw std::runtime_error("logits buffer too small");
+        }
         if (logits_size) {
             read_to(ctx->logits, logits_size * sizeof(float));
         }
     }
@@ -17596,18 +17598,22 @@
         uint64_t embeddings_size;
         read_to(&embeddings_size, sizeof(embeddings_size));
 
-        GGML_ASSERT(ctx->embd_size >= embeddings_size);
+        if (ctx->embd_size < embeddings_size) {
+            throw std::runtime_error("embeddings buffer too small");
+        }
 
         if (embeddings_size) {
             read_to(ctx->embd, embeddings_size * sizeof(float));
         }
     }
 
-    bool read_kv_cache_meta(struct llama_kv_cache & kv_self, uint32_t cell_count, llama_seq_id seq_id = -1) {
-        if (seq_id != -1) {
+    bool read_kv_cache_meta(struct llama_context * ctx, uint32_t cell_count, llama_seq_id dest_seq_id = -1) {
+        struct llama_kv_cache & kv_self = ctx->kv_self;
+
+        if (dest_seq_id != -1) {
             // single sequence
-            llama_kv_cache_seq_rm(kv_self, seq_id, -1, -1);
+            llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
 
             llama_batch batch = llama_batch_init(cell_count, 0, 1);
             batch.n_tokens = cell_count;
 
@@ -17618,11 +17624,14 @@
                 read_to(&pos, sizeof(pos));
                 read_to(&n_seq_id, sizeof(n_seq_id));
 
-                GGML_ASSERT(n_seq_id == 0);
+                if (n_seq_id != 0) {
+                    LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__);
+                    return false;
+                }
 
                 batch.pos[i] = pos;
                 batch.n_seq_id[i] = 1;
-                batch.seq_id[i][0] = seq_id;
+                batch.seq_id[i][0] = dest_seq_id;
             }
             if (!llama_kv_cache_find_slot(kv_self, batch)) {
                 llama_batch_free(batch);
@@ -17635,8 +17644,8 @@
             GGML_ASSERT(kv_self.head + cell_count <= kv_self.size);
             GGML_ASSERT(kv_self.cells[kv_self.head].pos == batch.pos[0]);
             GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].pos == batch.pos[cell_count - 1]);
-            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(seq_id));
-            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(seq_id));
+            GGML_ASSERT(kv_self.cells[kv_self.head].has_seq_id(dest_seq_id));
+            GGML_ASSERT(kv_self.cells[kv_self.head + cell_count - 1].has_seq_id(dest_seq_id));
 
             // Cleanup
             llama_batch_free(batch);
@@ -17661,12 +17670,16 @@
 
                 cell.pos = pos;
 
-                // TODO: more sanity checks for seq_ids
                 for (uint32_t j = 0; j < n_seq_id; ++j) {
-                    llama_seq_id s;
-                    read_to(&s, sizeof(s));
+                    llama_seq_id seq_id;
+                    read_to(&seq_id, sizeof(seq_id));
 
-                    cell.seq_id.insert(s);
+                    if (seq_id < 0 || seq_id >= llama_n_seq_max(ctx)) {
+                        LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
+                        return false;
+                    }
+
+                    cell.seq_id.insert(seq_id);
                 }
             }
 
@@ -17685,9 +17698,18 @@
         read_to(&v_trans, sizeof(v_trans));
         read_to(&n_layer, sizeof(n_layer));
 
-        GGML_ASSERT(n_layer == hparams.n_layer);
-        GGML_ASSERT(cell_count <= kv_self.size);
-        GGML_ASSERT(kv_self.v_trans == (bool) v_trans); // incompatible V transposition
+        if (n_layer != hparams.n_layer) {
+            LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer);
+            return false;
+        }
+        if (cell_count > kv_self.size) {
+            LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, kv_self.size);
+            return false;
+        }
+        if (kv_self.v_trans != (bool) v_trans) {
+            LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__);
+            return false;
+        }
 
         // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block
         for (uint32_t il = 0; il < n_layer; ++il) {
@@ -17698,7 +17720,6 @@
             read_to(&k_type_i_ref, sizeof(k_type_i_ref));
             const int32_t k_type_i = (int32_t)kv_self.k_l[il]->type;
             if (k_type_i != k_type_i_ref) {
-                // llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
                 LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il);
                 return false;
             }
@@ -17708,7 +17729,6 @@
             read_to(&k_size_row_ref, sizeof(k_size_row_ref));
             const size_t k_size_row = ggml_row_size(kv_self.k_l[il]->type, n_embd_k_gqa);
             if (k_size_row != k_size_row_ref) {
-                // llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
                 LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il);
                 return false;
             }
@@ -17728,7 +17748,6 @@
             read_to(&v_type_i_ref, sizeof(v_type_i_ref));
             const int32_t v_type_i = (int32_t)kv_self.v_l[il]->type;
             if (v_type_i != v_type_i_ref) {
-                // llama_kv_cache_seq_rm(kv_self, dest_seq_id, -1, -1);
                 LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il);
                 return false;
             }
@@ -17790,11 +17809,11 @@
         return true;
     }
 
-    bool read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
+    void read_kv_cache(struct llama_context * ctx, llama_seq_id seq_id = -1) {
         uint32_t cell_count;
         read_to(&cell_count, sizeof(cell_count));
 
-        bool res = read_kv_cache_meta(ctx->kv_self, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
+        bool res = read_kv_cache_meta(ctx, cell_count, seq_id) && read_kv_cache_data(ctx, cell_count);
 
         if (!res) {
             if (seq_id == -1) {
@@ -17802,9 +17821,8 @@
             } else {
                 llama_kv_cache_seq_rm(ctx, seq_id, -1, -1);
             }
+            throw std::runtime_error("failed to restore kv cache");
         }
-
-        return res;
     }
 };
 
@@ -17944,14 +17962,24 @@ static size_t llama_state_get_data_internal(struct llama_context * ctx, llama_da
 
size_t llama_state_get_data(struct llama_context * ctx, uint8_t * dst, size_t size) {
     llama_data_write_buffer data_ctx(dst, size);
-    return llama_state_get_data_internal(ctx, data_ctx);
+    try {
+        return llama_state_get_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error saving state: %s\n", __func__, err.what());
+        return 0;
+    }
 }
 
 // Returns the *actual* size of the state.
 // Intended to be used when saving to state to a buffer.
 size_t llama_state_get_size(struct llama_context * ctx) {
     llama_data_write_dummy data_ctx;
-    return llama_state_get_data_internal(ctx, data_ctx);
+    try {
+        return llama_state_get_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error getting state size: %s\n", __func__, err.what());
+        return 0;
+    }
 }
 
 static size_t llama_state_set_data_internal(struct llama_context * ctx, llama_data_read & data_ctx) {
@@ -17975,7 +18003,12 @@
 // Sets the state reading from the specified source address
 size_t llama_state_set_data(struct llama_context * ctx, const uint8_t * src, size_t size) {
     llama_data_read_buffer data_ctx(src, size);
-    return llama_state_set_data_internal(ctx, data_ctx);
+    try {
+        return llama_state_set_data_internal(ctx, data_ctx);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading state: %s\n", __func__, err.what());
+        return 0;
+    }
 }
 
 static bool llama_state_load_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
@@ -17987,7 +18020,7 @@
         const uint32_t version = file.read_u32();
 
         if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
-            LLAMA_LOG_ERROR("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
+            LLAMA_LOG_ERROR("%s: unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
             return false;
         }
     }
@@ -17997,7 +18030,7 @@
         const uint32_t n_token_count = file.read_u32();
 
         if (n_token_count > n_token_capacity) {
-            LLAMA_LOG_ERROR("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            LLAMA_LOG_ERROR("%s: token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
             return false;
         }
 
@@ -18013,7 +18046,7 @@
         const size_t n_read = llama_state_set_data_internal(ctx, data_ctx);
 
         if (n_read != n_state_size_cur) {
-            LLAMA_LOG_ERROR("%s : did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
+            LLAMA_LOG_ERROR("%s: did not read all of the session file data! size %zu, got %zu\n", __func__, n_state_size_cur, n_read);
             return false;
         }
     }
@@ -18024,7 +18057,7 @@ bool llama_state_load_file(struct llama_context * ctx, const char * path_session
     try {
         return llama_state_load_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading session file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error loading session file: %s\n", __func__, err.what());
         return false;
     }
 }
@@ -18050,7 +18083,7 @@ bool llama_state_save_file(struct llama_context * ctx, const char * path_session
     try {
         return llama_state_save_file_internal(ctx, path_session, tokens, n_token_count);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving session file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error saving session file: %s\n", __func__, err.what());
         return false;
     }
 }
@@ -18073,7 +18106,7 @@ size_t llama_state_seq_get_data(struct llama_context * ctx, uint8_t * dst, size_
     try {
         return llama_state_seq_get_data_internal(ctx, data_ctx, seq_id);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving sequence state: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error saving sequence state: %s\n", __func__, err.what());
         return 0;
     }
 }
@@ -18091,7 +18124,7 @@ size_t llama_state_seq_set_data(struct llama_context * ctx, const uint8_t * src,
     try {
         return llama_state_seq_set_data_internal(ctx, data_ctx, dest_seq_id);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading sequence state: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error loading sequence state: %s\n", __func__, err.what());
         return 0;
     }
 }
@@ -18162,7 +18195,7 @@ size_t llama_state_seq_save_file(struct llama_context * ctx, const char * filepa
     try {
         return llama_state_seq_save_file_internal(ctx, filepath, seq_id, tokens, n_token_count);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error saving sequence state file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error saving sequence state file: %s\n", __func__, err.what());
         return 0;
     }
 }
@@ -18171,7 +18204,7 @@ size_t llama_state_seq_load_file(struct llama_context * ctx, const char * filepa
     try {
         return llama_state_seq_load_file_internal(ctx, filepath, dest_seq_id, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("error loading sequence state file: %s\n", err.what());
+        LLAMA_LOG_ERROR("%s: error loading sequence state file: %s\n", __func__, err.what());
        return 0;
     }
 }
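-- 
With this change, an invalid or truncated session file is reported through
the existing return values (false, or 0 bytes read/written, after logging)
instead of aborting the whole process through GGML_ASSERT. A minimal sketch
of the file-based path, using the llama_state_load_file() signature shown
above; the helper name, token capacity, and session path handling are
illustrative, not part of the patch:

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    // Returns true if the session was restored, false if it was rejected.
    static bool try_restore_session(llama_context * ctx, const char * path) {
        std::vector<llama_token> tokens(512); // capacity for the saved prompt
        size_t n_tokens = 0;

        // A bad magic/version, mismatched model arch, unparsable RNG state,
        // oversized logits/embeddings or a corrupt KV cache now logs an error
        // and returns false here instead of crashing the caller.
        if (!llama_state_load_file(ctx, path, tokens.data(), tokens.size(), &n_tokens)) {
            fprintf(stderr, "session restore failed; starting fresh\n");
            return false;
        }
        return true;
    }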
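The in-memory variants follow the same convention, so a save/restore round
trip can be checked the same way. A sketch under the same assumptions (the
helper name is again illustrative; llama_state_get_size(),
llama_state_get_data() and llama_state_set_data() are the functions wrapped
in try/catch above):

    // Same includes as the previous sketch.
    // Serializes the full context state into a buffer and loads it back;
    // under this patch a return of 0 from any of these calls signals failure.
    static bool roundtrip_state(llama_context * ctx) {
        std::vector<uint8_t> buf(llama_state_get_size(ctx));
        if (buf.empty()) {
            return false; // the size query itself failed
        }
        if (llama_state_get_data(ctx, buf.data(), buf.size()) == 0) {
            return false; // serialization failed
        }
        if (llama_state_set_data(ctx, buf.data(), buf.size()) == 0) {
            return false; // deserialization failed
        }
        return true;
    }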