From 615a3a4a50765d9fbb895ddd34f943c8bc92ce7e Mon Sep 17 00:00:00 2001 From: Francis Couture-Harpin Date: Tue, 19 Mar 2024 15:01:21 -0400 Subject: [PATCH] llama : clearer error messages for invalid logits or embeddings ids * llama : assert all models that can have inp_out_ids Since the graph topology is now constant, this presence check can be done even when there are no outputs. * llama : assert logits and embd buffers exist before writing to them --- llama.cpp | 66 +++++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 47 insertions(+), 19 deletions(-) diff --git a/llama.cpp b/llama.cpp index faf65e339..157ddadd8 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8954,7 +8954,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { ggml_backend_tensor_set(lctx.inp_pos, batch.pos, 0, n_tokens*ggml_element_size(lctx.inp_pos)); } - if (lctx.n_outputs > 0 && (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE)) { + if (hparams.causal_attn || cparams.pooling_type == LLAMA_POOLING_TYPE_NONE) { GGML_ASSERT(lctx.inp_out_ids && "every model that can must skip unused outputs"); const int64_t n_tokens = batch.n_tokens; @@ -9514,6 +9514,7 @@ static int llama_decode_internal( if (res) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(lctx.sched, res); GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(lctx.logits != nullptr); float * logits_out = lctx.logits + n_outputs_prev*n_vocab; const int32_t n_outputs_new = lctx.n_outputs; @@ -9534,6 +9535,7 @@ static int llama_decode_internal( case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings + GGML_ASSERT(lctx.embd != nullptr); float * embd_out = lctx.embd + n_outputs_prev*n_embd; const int32_t n_outputs_new = lctx.n_outputs; @@ -14623,20 +14625,33 @@ float * llama_get_logits(struct llama_context * ctx) { } float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) { - const int32_t j = ctx->output_ids[i]; - llama_synchronize(ctx); - if (ctx->logits && 0 <= j && (size_t) j < ctx->output_size) { + try { + if (ctx->logits == nullptr) { + throw std::runtime_error("no logits"); + } + if ((size_t) i >= ctx->output_ids.size()) { + throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); + } + const int32_t j = ctx->output_ids[i]; + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if ((size_t) j >= ctx->output_size) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size)); + } + return ctx->logits + j*ctx->model.hparams.n_vocab; - } - LLAMA_LOG_ERROR("%s: invalid logits id %i, reason: %s (j=%i, output_size=%li)\n", - __func__, i, !ctx->logits ? "no logits" : j < 0 ? "batch.logits[i] wasn't true" : "too big", - j, ctx->output_size); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG - GGML_ASSERT(false); + GGML_ASSERT(false); #endif - return nullptr; + return nullptr; + } } float * llama_get_embeddings(struct llama_context * ctx) { @@ -14646,20 +14661,33 @@ float * llama_get_embeddings(struct llama_context * ctx) { } float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { - const int32_t j = ctx->output_ids[i]; - llama_synchronize(ctx); - if (ctx->embd && 0 <= j && (size_t) j < ctx->output_size) { + try { + if (ctx->embd == nullptr) { + throw std::runtime_error("no embeddings"); + } + if ((size_t) i >= ctx->output_ids.size()) { + throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); + } + const int32_t j = ctx->output_ids[i]; + + if (j < 0) { + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + if ((size_t) j >= ctx->output_size) { + // This should not happen + throw std::runtime_error(format("corrupt output buffer (j=%d, output_size=%lu)", j, ctx->output_size)); + } + return ctx->embd + j*ctx->model.hparams.n_embd; - } - LLAMA_LOG_ERROR("%s: invalid embeddings id %i, reason: %s (j=%i, output_size=%li)\n", - __func__, i, !ctx->embd ? "no embeddings" : j < 0 ? "batch.logits[i] wasn't true" : "too big", - j, ctx->output_size); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG - GGML_ASSERT(false); + GGML_ASSERT(false); #endif - return nullptr; + return nullptr; + } } float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {