From 99291c04f734c5eb42d7987e12dc8f62329d9934 Mon Sep 17 00:00:00 2001
From: Justine Tunney
Date: Tue, 21 May 2024 14:13:38 -0700
Subject: [PATCH] Check for llama_get_logits_ith() errors

Embedding models like BERT don't have logits. This caused llamafile to
crash for users who tried to run inference on mxbai-embed-large-v1. This
change prevents that crash by making llama_get_logits_ith() report
failure and by having callers check its result. Since it is possible for
this function to fail, checking the result is good defensive coding in
any case. The older exception-based error handling has also been
refactored away, since it is no longer needed.
---
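Every call site below applies the same guard. A minimal sketch of the
pattern, assuming a typical decode loop (the ctx and batch names here are
illustrative, not taken from any one hunk):

    // After llama_decode(), fetch the logits for the batch's last token.
    float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
    if (!logits) {
        // Embedding-only models such as BERT produce no logits; the
        // accessor now logs the reason and returns nullptr instead of
        // throwing, so callers can fail gracefully.
        fprintf(stderr, "%s: llama_get_logits_ith failed\n", __func__);
        return 1;
    }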
 common/sampling.cpp                        |   9 ++
 examples/batched/batched.cpp               |   3 +
 examples/gritlm/gritlm.cpp                 |   6 +
 .../app/src/main/cpp/llama-android.cpp     |   3 +
 examples/passkey/passkey.cpp               |   3 +
 examples/perplexity/perplexity.cpp         |   3 +
 examples/simple/simple.cpp                 |   3 +
 llama.cpp                                  | 127 +++++++++---
 8 files changed, 92 insertions(+), 65 deletions(-)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index 7fc2e2158..88e5e7b1b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -195,6 +195,9 @@ static llama_token llama_sampling_sample_impl(
     llama_token id = 0;
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }
 
     if (temp < 0.0) {
         // greedy sampling, with probs
@@ -284,6 +287,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     // Get a pointer to the logits
     float * logits = llama_get_logits_ith(ctx_main, idx);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }
 
     if (ctx_sampling->grammar != NULL && !apply_grammar) {
         GGML_ASSERT(original_logits != NULL);
@@ -298,6 +304,9 @@ static llama_token_data_array llama_sampling_prepare_impl(
 
     if (ctx_cfg) {
         float * logits_guidance = llama_get_logits_ith(ctx_cfg, idx);
+        if (!logits_guidance) {
+            throw std::runtime_error("llama_get_logits_ith failed");
+        }
         llama_sample_apply_guidance(ctx_main, logits, logits_guidance, params.cfg_scale);
     }
 
diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp
index be30d20bf..9812dcb3d 100644
--- a/examples/batched/batched.cpp
+++ b/examples/batched/batched.cpp
@@ -169,6 +169,9 @@ int main(int argc, char ** argv) {
 
             auto   n_vocab = llama_n_vocab(model);
             auto * logits  = llama_get_logits_ith(ctx, i_batch[i]);
+            if (!logits) {
+                return 1;
+            }
 
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index 52fd719b3..5a60ca9cf 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -58,6 +58,9 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
         // sum up all token embeddings
         for (int32_t k = n_inst; k < n_toks; k++) {
             float * emb = llama_get_embeddings_ith(ctx, k);
+            if (!emb) {
+                throw std::runtime_error("llama_get_embeddings_ith failed");
+            }
             for (uint64_t j = 0; j < n_embd; j++) {
                 emb_unorm[j] += emb[j];
             }
@@ -114,6 +117,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
         llama_decode(ctx, bat);
 
         auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
+        if (!logits) {
+            throw std::runtime_error("llama_get_logits_ith failed");
+        }
 
         auto candidates   = std::vector<llama_token_data>(llama_n_vocab(mdl));
         auto n_candidates = (int32_t)candidates.size();
diff --git a/examples/llama.android/app/src/main/cpp/llama-android.cpp b/examples/llama.android/app/src/main/cpp/llama-android.cpp
index 4af9de303..e20ba7a8e 100644
--- a/examples/llama.android/app/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/app/src/main/cpp/llama-android.cpp
@@ -394,6 +394,9 @@ Java_com_example_llama_Llm_completion_1loop(
 
     auto n_vocab = llama_n_vocab(model);
     auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
+    if (!logits) {
+        throw std::runtime_error("llama_get_logits_ith failed");
+    }
 
     std::vector<llama_token_data> candidates;
    candidates.reserve(n_vocab);
diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp
index f2ef9ca10..367576016 100644
--- a/examples/passkey/passkey.cpp
+++ b/examples/passkey/passkey.cpp
@@ -239,6 +239,9 @@ int main(int argc, char ** argv) {
     {
         auto   n_vocab = llama_n_vocab(model);
         auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+        if (!logits) {
+            return 1;
+        }
 
         std::vector<llama_token_data> candidates;
         candidates.reserve(n_vocab);
diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp
index bae014e6f..d5a8d59fc 100644
--- a/examples/perplexity/perplexity.cpp
+++ b/examples/perplexity/perplexity.cpp
@@ -638,6 +638,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
 
         for (int seq = 0; seq < n_seq_batch; seq++) {
             const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first);
+            if (!all_logits) {
+                return {tokens, -1, logit_history, prob_history};
+            }
 
             llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first;
             if (!params.logits_file.empty()) {
diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp
index b0f8e0fdc..63ba48c1b 100644
--- a/examples/simple/simple.cpp
+++ b/examples/simple/simple.cpp
@@ -120,6 +120,9 @@ int main(int argc, char ** argv) {
         {
             auto   n_vocab = llama_n_vocab(model);
             auto * logits  = llama_get_logits_ith(ctx, batch.n_tokens - 1);
+            if (!logits) {
+                return 1;
+            }
 
             std::vector<llama_token_data> candidates;
             candidates.reserve(n_vocab);
diff --git a/llama.cpp b/llama.cpp
index abff8c1c0..35c3159a9 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -17301,42 +17301,39 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits;
 }
 
+static float * llama_get_logits_ith_fail(int i, std::string reason) {
+    LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, reason.c_str());
+#ifndef NDEBUG
+    GGML_ASSERT(false);
+#endif
+    return nullptr;
+}
+
 float * llama_get_logits_ith(struct llama_context * ctx, int32_t i) {
     int32_t j = -1;
 
     llama_synchronize(ctx);
-
-    try {
-        if (ctx->logits == nullptr) {
-            throw std::runtime_error("no logits");
-        }
-
-        if (i < 0) {
-            j = ctx->n_outputs + i;
-            if (j < 0) {
-                throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs));
-            }
-        } else if ((size_t) i >= ctx->output_ids.size()) {
-            throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size()));
-        } else {
-            j = ctx->output_ids[i];
-        }
-
-        if (j < 0) {
-            throw std::runtime_error(format("batch.logits[%d] != true", i));
-        }
-        if (j >= ctx->n_outputs) {
-            // This should not happen
-            throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs));
-        }
-
-        return ctx->logits + j*ctx->model.hparams.n_vocab;
-    } catch (const std::exception & err) {
-        LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what());
-#ifndef NDEBUG
-        GGML_ASSERT(false);
-#endif
-        return nullptr;
+    if (ctx->logits == nullptr) {
+        // this can happen for embeddings models like bert
+        return llama_get_logits_ith_fail(i, "no logits");
     }
+    if (i < 0) {
+        j = ctx->n_outputs + i;
format("negative index out of range [0, %d)", ctx->n_outputs)); + } + } else if ((size_t) i >= ctx->output_ids.size()) { + return llama_get_logits_ith_fail(i, format("out of range [0, %lu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; + } + if (j < 0) { + return llama_get_logits_ith_fail(i, format("batch.logits[%d] != true", i)); + } + if (j >= ctx->n_outputs) { + // This should not happen + return llama_get_logits_ith_fail(i, format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); + } + return ctx->logits + j*ctx->model.hparams.n_vocab; } float * llama_get_embeddings(struct llama_context * ctx) { @@ -17345,43 +17342,43 @@ float * llama_get_embeddings(struct llama_context * ctx) { return ctx->embd; } +static float * llama_get_embeddings_ith_fail(int i, std::string reason) { + LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, reason.c_str()); +#ifndef NDEBUG + GGML_ASSERT(false); +#endif + return nullptr; +} + float * llama_get_embeddings_ith(struct llama_context * ctx, int32_t i) { int32_t j = -1; - llama_synchronize(ctx); - - try { - if (ctx->embd == nullptr) { - throw std::runtime_error("no embeddings"); - } - - if (i < 0) { - j = ctx->n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", ctx->n_outputs)); - } - } else if ((size_t) i >= ctx->output_ids.size()) { - throw std::runtime_error(format("out of range [0, %lu)", ctx->output_ids.size())); - } else { - j = ctx->output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= ctx->n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%d, n_outputs=%d)", j, ctx->n_outputs)); - } - - return ctx->embd + j*ctx->model.hparams.n_embd; - } catch (const std::exception & err) { - LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); -#ifndef NDEBUG - GGML_ASSERT(false); -#endif - return nullptr; + if (ctx->embd == nullptr) { + return llama_get_embeddings_ith_fail(i, "no embeddings"); } + if (i < 0) { + j = ctx->n_outputs + i; + if (j < 0) { + return llama_get_embeddings_ith_fail( + i, format("negative index out of range [0, %d)", ctx->n_outputs)); + } + } else if ((size_t) i >= ctx->output_ids.size()) { + return llama_get_embeddings_ith_fail( + i, format("out of range [0, %lu)", ctx->output_ids.size())); + } else { + j = ctx->output_ids[i]; + } + if (j < 0) { + return llama_get_embeddings_ith_fail( + i, format("batch.logits[%d] != true", i)); + } + if (j >= ctx->n_outputs) { + // This should not happen + return llama_get_embeddings_ith_fail( + i, format("corrupt output buffer (j=%d, n_outputs=%d)", + j, ctx->n_outputs)); + } + return ctx->embd + j*ctx->model.hparams.n_embd; } float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id) {